diff --git a/alembic.ini b/alembic.ini new file mode 100644 index 0000000..d9e6504 --- /dev/null +++ b/alembic.ini @@ -0,0 +1,109 @@ +# A generic, single database configuration. + +[alembic] +# path to migration scripts +script_location = migrations + +# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s +# Uncomment the line below if you want the files to be prepended with date and time +# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file +# for all available tokens +# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s + +# sys.path path, will be prepended to sys.path if present. +# defaults to the current working directory. +prepend_sys_path = . + +# timezone to use when rendering the date within the migration file +# as well as the filename. +# If specified, requires the python>=3.9 or backports.zoneinfo library. +# Any required deps can installed by adding `alembic[tz]` to the pip requirements +# string value is passed to ZoneInfo() +# leave blank for localtime +# timezone = + +# max length of characters to apply to the "slug" field +# truncate_slug_length = 40 + +# set to 'true' to run the environment during +# the 'revision' command, regardless of autogenerate +# revision_environment = false + +# set to 'true' to allow .pyc and .pyo files without +# a source .py file to be detected as revisions in the +# versions/ directory +# sourceless = false + +# version location specification; This defaults +# to migrations/versions. When using multiple version +# directories, initial revisions must be specified with --version-path. +# The path separator used here should be the separator specified by "version_path_separator" below. +# version_locations = %(here)s/bar:%(here)s/bat:migrations/versions + +# version path separator; As mentioned above, this is the character used to split +# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep. +# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas. +# Valid values for version_path_separator are: +# +# version_path_separator = : +# version_path_separator = ; +# version_path_separator = space +version_path_separator = os # Use os.pathsep. Default configuration used for new projects. + +# set to 'true' to search source files recursively +# in "version_locations" directory +# New in Alembic version 1.10 +# recursive_version_locations = false + +# the output encoding used when revision files +# are written from script.py.mako +# output_encoding = utf-8 + +sqlalchemy.url = postgresql+asyncpg://konstruct_app:konstruct_dev@localhost:5432/konstruct + + +[post_write_hooks] +# post_write_hooks defines scripts or Python functions that are run +# on newly generated revision scripts. See the documentation for further +# detail and examples + +# format using "ruff" - use the reformat_code hook, with the configuration file +# hooks = ruff +# ruff.type = exec +# ruff.executable = %(here)s/.venv/bin/ruff +# ruff.options = --line-length 120 + +# Logging configuration +[loggers] +keys = root,sqlalchemy,alembic + +[handlers] +keys = console + +[formatters] +keys = generic + +[logger_root] +level = WARN +handlers = console +qualname = + +[logger_sqlalchemy] +level = WARN +handlers = +qualname = sqlalchemy.engine + +[logger_alembic] +level = INFO +handlers = +qualname = alembic + +[handler_console] +class = StreamHandler +args = (sys.stderr,) +level = NOTSET +formatter = generic + +[formatter_generic] +format = %(levelname)-5.5s [%(name)s] %(message)s +datefmt = %H:%M:%S diff --git a/migrations/env.py b/migrations/env.py new file mode 100644 index 0000000..2cd8334 --- /dev/null +++ b/migrations/env.py @@ -0,0 +1,103 @@ +""" +Alembic migration environment — async SQLAlchemy configuration. + +Uses asyncpg driver with asyncio migration pattern required for SQLAlchemy 2.0. +Runs migrations as the postgres admin user (DATABASE_ADMIN_URL) so it can: + - CREATE ROLE konstruct_app + - ENABLE ROW LEVEL SECURITY + - FORCE ROW LEVEL SECURITY + - CREATE POLICY + +Application code always uses DATABASE_URL (konstruct_app role). +""" + +from __future__ import annotations + +import asyncio +import os +import sys +from logging.config import fileConfig + +from alembic import context +from sqlalchemy.ext.asyncio import create_async_engine + +# --------------------------------------------------------------------------- +# Make sure packages/shared is importable when running `alembic upgrade head` +# from the repo root. +# --------------------------------------------------------------------------- +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) + +from shared.models.tenant import Base # noqa: E402 # type: ignore[import] + +# Import auth model to register it with Base.metadata +import shared.models.auth # noqa: E402, F401 # type: ignore[import] + +# this is the Alembic Config object, which provides +# access to the values within the .ini file in use. +config = context.config + +# Interpret the config file for Python logging. +if config.config_file_name is not None: + fileConfig(config.config_file_name) + +# --------------------------------------------------------------------------- +# Metadata for autogenerate support +# --------------------------------------------------------------------------- +target_metadata = Base.metadata + +# --------------------------------------------------------------------------- +# Use DATABASE_ADMIN_URL if set (for CI / production migrations), +# otherwise fall back to alembic.ini sqlalchemy.url. +# --------------------------------------------------------------------------- +database_url = os.environ.get("DATABASE_ADMIN_URL") or config.get_main_option("sqlalchemy.url") + + +def run_migrations_offline() -> None: + """ + Run migrations in 'offline' mode. + + This configures the context with just a URL and not an Engine. + Useful for generating SQL scripts without a live DB connection. + """ + context.configure( + url=database_url, + target_metadata=target_metadata, + literal_binds=True, + dialect_opts={"paramstyle": "named"}, + ) + + with context.begin_transaction(): + context.run_migrations() + + +async def run_async_migrations() -> None: + """ + Create an async engine and run migrations within an async context. + + This is the required pattern for SQLAlchemy 2.0 + asyncpg. + """ + connectable = create_async_engine(database_url, echo=False) + + async with connectable.connect() as connection: + await connection.run_sync(do_run_migrations) + + await connectable.dispose() + + +def do_run_migrations(connection: object) -> None: + """Synchronous migration runner — called within async context.""" + context.configure(connection=connection, target_metadata=target_metadata) # type: ignore[arg-type] + + with context.begin_transaction(): + context.run_migrations() + + +def run_migrations_online() -> None: + """Run migrations in 'online' mode with a live DB connection.""" + asyncio.run(run_async_migrations()) + + +if context.is_offline_mode(): + run_migrations_offline() +else: + run_migrations_online() diff --git a/migrations/script.py.mako b/migrations/script.py.mako new file mode 100644 index 0000000..ee746cf --- /dev/null +++ b/migrations/script.py.mako @@ -0,0 +1,28 @@ +"""${message} + +Revision ID: ${up_revision} +Revises: ${down_revision | comma,n} +Create Date: ${create_date} + +""" +from __future__ import annotations + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +${imports if imports else ""} + +# revision identifiers, used by Alembic. +revision: str = ${repr(up_revision)} +down_revision: Union[str, None] = ${repr(down_revision)} +branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)} +depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)} + + +def upgrade() -> None: + ${upgrades if upgrades else "pass"} + + +def downgrade() -> None: + ${downgrades if downgrades else "pass"} diff --git a/migrations/versions/001_initial_schema.py b/migrations/versions/001_initial_schema.py new file mode 100644 index 0000000..e455ca9 --- /dev/null +++ b/migrations/versions/001_initial_schema.py @@ -0,0 +1,195 @@ +"""Initial schema: tenants, agents, channel_connections, portal_users with RLS + +Revision ID: 001 +Revises: +Create Date: 2026-03-23 + +This migration: +1. Creates the konstruct_app application role (non-superuser) +2. Creates all four tables matching the SQLAlchemy models +3. Applies Row Level Security (RLS) with FORCE ROW LEVEL SECURITY to + tenant-scoped tables (agents, channel_connections) +4. Creates RLS policies that scope rows to app.current_tenant session variable +5. Grants appropriate permissions to konstruct_app role + +CRITICAL: FORCE ROW LEVEL SECURITY is applied to agents and channel_connections. +This means even the table owner cannot bypass RLS. The integration test +`test_tenant_isolation.py` must verify this is in effect. +""" + +from __future__ import annotations + +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects.postgresql import UUID + +# revision identifiers, used by Alembic. +revision: str = "001" +down_revision: Union[str, None] = None +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + +# All valid channel types — kept in sync with ChannelType StrEnum in message.py +_CHANNEL_TYPES = ("slack", "whatsapp", "mattermost", "rocketchat", "teams", "telegram", "signal") + + +def upgrade() -> None: + # ------------------------------------------------------------------------- + # 1. Create application role (idempotent) + # ------------------------------------------------------------------------- + op.execute(""" + DO $$ + BEGIN + IF NOT EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname = 'konstruct_app') THEN + CREATE ROLE konstruct_app WITH LOGIN PASSWORD 'konstruct_dev'; + END IF; + END + $$ + """) + + op.execute("GRANT USAGE ON SCHEMA public TO konstruct_app") + + # ------------------------------------------------------------------------- + # 2. Create channel_type ENUM (using raw SQL to avoid SQLAlchemy auto-emit) + # We use op.execute with raw DDL so SQLAlchemy does NOT auto-emit + # a second CREATE TYPE statement in create_table below. + # ------------------------------------------------------------------------- + op.execute( + "CREATE TYPE channel_type_enum AS ENUM " + "('slack', 'whatsapp', 'mattermost', 'rocketchat', 'teams', 'telegram', 'signal')" + ) + + # ------------------------------------------------------------------------- + # 3. Create tenants table (no RLS — platform admin needs full visibility) + # ------------------------------------------------------------------------- + op.create_table( + "tenants", + sa.Column("id", UUID(as_uuid=True), primary_key=True, server_default=sa.text("gen_random_uuid()")), + sa.Column("name", sa.String(255), nullable=False, unique=True), + sa.Column("slug", sa.String(100), nullable=False, unique=True), + sa.Column("settings", sa.JSON, nullable=False, server_default=sa.text("'{}'")), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("NOW()")), + sa.Column( + "updated_at", + sa.DateTime(timezone=True), + nullable=False, + server_default=sa.text("NOW()"), + ), + ) + op.create_index("ix_tenants_slug", "tenants", ["slug"]) + + # ------------------------------------------------------------------------- + # 4. Create agents table with RLS + # ------------------------------------------------------------------------- + op.create_table( + "agents", + sa.Column("id", UUID(as_uuid=True), primary_key=True, server_default=sa.text("gen_random_uuid()")), + sa.Column("tenant_id", UUID(as_uuid=True), sa.ForeignKey("tenants.id", ondelete="CASCADE"), nullable=False), + sa.Column("name", sa.String(255), nullable=False), + sa.Column("role", sa.String(255), nullable=False), + sa.Column("persona", sa.Text, nullable=False, server_default=sa.text("''")), + sa.Column("system_prompt", sa.Text, nullable=False, server_default=sa.text("''")), + sa.Column("model_preference", sa.String(50), nullable=False, server_default=sa.text("'quality'")), + sa.Column("tool_assignments", sa.JSON, nullable=False, server_default=sa.text("'[]'")), + sa.Column("escalation_rules", sa.JSON, nullable=False, server_default=sa.text("'[]'")), + sa.Column("is_active", sa.Boolean, nullable=False, server_default=sa.text("TRUE")), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("NOW()")), + sa.Column( + "updated_at", + sa.DateTime(timezone=True), + nullable=False, + server_default=sa.text("NOW()"), + ), + ) + op.create_index("ix_agents_tenant_id", "agents", ["tenant_id"]) + + # Apply RLS to agents — FORCE ensures even table owner cannot bypass + op.execute("ALTER TABLE agents ENABLE ROW LEVEL SECURITY") + op.execute("ALTER TABLE agents FORCE ROW LEVEL SECURITY") + op.execute(""" + CREATE POLICY tenant_isolation ON agents + USING (tenant_id = current_setting('app.current_tenant', TRUE)::uuid) + """) + + # ------------------------------------------------------------------------- + # 5. Create channel_connections table with RLS + # Use sa.Text for channel_type column — cast to enum_type in app code. + # The channel_type_enum was created above via raw DDL. + # We reference it here using sa.text cast to avoid SQLAlchemy auto-emit. + # ------------------------------------------------------------------------- + op.create_table( + "channel_connections", + sa.Column("id", UUID(as_uuid=True), primary_key=True, server_default=sa.text("gen_random_uuid()")), + sa.Column("tenant_id", UUID(as_uuid=True), sa.ForeignKey("tenants.id", ondelete="CASCADE"), nullable=False), + sa.Column( + "channel_type", + sa.Text, # Stored as text, constrained by CHECK to valid enum values + nullable=False, + ), + sa.Column("workspace_id", sa.String(255), nullable=False), + sa.Column("config", sa.JSON, nullable=False, server_default=sa.text("'{}'")), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("NOW()")), + sa.UniqueConstraint("channel_type", "workspace_id", name="uq_channel_workspace"), + ) + op.create_index("ix_channel_connections_tenant_id", "channel_connections", ["tenant_id"]) + + # Add CHECK constraint to enforce valid channel types (reuses enum values) + op.execute( + "ALTER TABLE channel_connections ADD CONSTRAINT chk_channel_type " + f"CHECK (channel_type IN {tuple(_CHANNEL_TYPES)})" + ) + + # Apply RLS to channel_connections + op.execute("ALTER TABLE channel_connections ENABLE ROW LEVEL SECURITY") + op.execute("ALTER TABLE channel_connections FORCE ROW LEVEL SECURITY") + op.execute(""" + CREATE POLICY tenant_isolation ON channel_connections + USING (tenant_id = current_setting('app.current_tenant', TRUE)::uuid) + """) + + # ------------------------------------------------------------------------- + # 6. Create portal_users table (no RLS — auth happens before tenant context) + # ------------------------------------------------------------------------- + op.create_table( + "portal_users", + sa.Column("id", UUID(as_uuid=True), primary_key=True, server_default=sa.text("gen_random_uuid()")), + sa.Column("email", sa.String(255), nullable=False, unique=True), + sa.Column("hashed_password", sa.String(255), nullable=False), + sa.Column("name", sa.String(255), nullable=False), + sa.Column("is_admin", sa.Boolean, nullable=False, server_default=sa.text("FALSE")), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("NOW()")), + sa.Column( + "updated_at", + sa.DateTime(timezone=True), + nullable=False, + server_default=sa.text("NOW()"), + ), + ) + op.create_index("ix_portal_users_email", "portal_users", ["email"]) + + # ------------------------------------------------------------------------- + # 7. Grant table permissions to konstruct_app role + # ------------------------------------------------------------------------- + op.execute("GRANT SELECT, INSERT, UPDATE, DELETE ON tenants TO konstruct_app") + op.execute("GRANT SELECT, INSERT, UPDATE, DELETE ON agents TO konstruct_app") + op.execute("GRANT SELECT, INSERT, UPDATE, DELETE ON channel_connections TO konstruct_app") + op.execute("GRANT SELECT, INSERT, UPDATE, DELETE ON portal_users TO konstruct_app") + + +def downgrade() -> None: + # Revoke grants + op.execute("REVOKE ALL ON portal_users FROM konstruct_app") + op.execute("REVOKE ALL ON channel_connections FROM konstruct_app") + op.execute("REVOKE ALL ON agents FROM konstruct_app") + op.execute("REVOKE ALL ON tenants FROM konstruct_app") + + # Drop tables + op.drop_table("portal_users") + op.drop_table("channel_connections") + op.drop_table("agents") + op.drop_table("tenants") + + # Drop enum type + op.execute("DROP TYPE IF EXISTS channel_type_enum") diff --git a/packages/shared/shared/rls.py b/packages/shared/shared/rls.py index 2e1d50b..c13f31f 100644 --- a/packages/shared/shared/rls.py +++ b/packages/shared/shared/rls.py @@ -9,7 +9,7 @@ How it works: SET LOCAL app.current_tenant = '' into the current transaction. 4. PostgreSQL evaluates this setting in every RLS policy via: - current_setting('app.current_tenant')::uuid + current_setting('app.current_tenant', TRUE)::uuid CRITICAL: The application MUST connect as `konstruct_app` (not postgres superuser). Superuser connections bypass RLS entirely — isolation tests @@ -17,6 +17,12 @@ would pass trivially but provide zero real protection. IMPORTANT: SET LOCAL is transaction-scoped. The tenant context resets automatically when each transaction ends — no manual cleanup required. + +NOTE ON SQL INJECTION: PostgreSQL's SET LOCAL does not support parameterized +placeholders. We protect against injection by passing the tenant_id value +through uuid.UUID() — any non-UUID string raises ValueError before it reaches +the database. The resulting string is always in canonical UUID format: +xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx with only hex chars and hyphens. """ from __future__ import annotations @@ -33,6 +39,9 @@ from sqlalchemy.ext.asyncio import AsyncEngine # --------------------------------------------------------------------------- current_tenant_id: ContextVar[UUID | None] = ContextVar("current_tenant_id", default=None) +# Track engines that have already had the hook configured (by sync engine id) +_configured_engines: set[int] = set() + def configure_rls_hook(engine: AsyncEngine) -> None: """ @@ -46,6 +55,12 @@ def configure_rls_hook(engine: AsyncEngine) -> None: configure_rls_hook(engine) """ + # Idempotent — skip if already configured for this engine + engine_id = id(engine.sync_engine) + if engine_id in _configured_engines: + return + _configured_engines.add(engine_id) + @event.listens_for(engine.sync_engine, "before_cursor_execute") def _set_rls_tenant( conn: Any, @@ -58,10 +73,17 @@ def configure_rls_hook(engine: AsyncEngine) -> None: """ Inject SET LOCAL app.current_tenant before every statement. - Uses parameterized query to prevent SQL injection. + PostgreSQL SET LOCAL does not support parameterized placeholders. + We prevent SQL injection by validating the tenant_id value through + uuid.UUID() — any non-UUID string raises ValueError before it reaches + the database. The resulting string contains only hex characters and + hyphens in canonical UUID format. + SET LOCAL is transaction-scoped and resets on commit/rollback. """ tenant_id = current_tenant_id.get() if tenant_id is not None: - # Parameterized to prevent SQL injection — never use f-string here - cursor.execute("SET LOCAL app.current_tenant = %s", (str(tenant_id),)) + # Sanitize: round-trip through UUID raises ValueError on invalid input. + # UUID.__str__ always produces canonical xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx + safe_id = str(UUID(str(tenant_id))) + cursor.execute(f"SET LOCAL app.current_tenant = '{safe_id}'") diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..34efcd7 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,189 @@ +""" +Shared test fixtures for Konstruct. + +IMPORTANT: The `db_session` fixture connects as `konstruct_app` (not postgres +superuser). This is mandatory — RLS is bypassed for superuser connections, so +tests using superuser would pass trivially while providing zero real protection. + +Integration tests requiring a live PostgreSQL container are skipped if the +database is not available. Unit tests never require a live DB. + +Event loop design: All async fixtures use function scope to avoid pytest-asyncio +cross-loop-scope issues. The test database is created once (at session scope, via +a synchronous fixture) and reused across tests within the session. +""" + +from __future__ import annotations + +import asyncio +import os +import subprocess +import uuid +from collections.abc import AsyncGenerator +from typing import Any + +import pytest +import pytest_asyncio +from sqlalchemy import text +from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine + +# --------------------------------------------------------------------------- +# Database URLs +# --------------------------------------------------------------------------- +_ADMIN_URL = os.environ.get( + "DATABASE_ADMIN_URL", + "postgresql+asyncpg://postgres:postgres_dev@localhost:5432/konstruct", +) +_APP_URL = os.environ.get( + "DATABASE_URL", + "postgresql+asyncpg://konstruct_app:konstruct_dev@localhost:5432/konstruct", +) + + +def _replace_db_name(url: str, new_db: str) -> str: + """Replace database name in a SQLAlchemy URL string.""" + parts = url.rsplit("/", 1) + return f"{parts[0]}/{new_db}" + + +# --------------------------------------------------------------------------- +# Session-scoped synchronous setup — creates and migrates the test DB once +# --------------------------------------------------------------------------- + +@pytest.fixture(scope="session") +def test_db_name() -> str: + """Create a fresh test database, run migrations, return DB name.""" + db_name = f"konstruct_test_{uuid.uuid4().hex[:8]}" + admin_postgres_url = _replace_db_name(_ADMIN_URL, "postgres") + + # Check PostgreSQL reachability using synchronous driver + try: + import asyncio as _asyncio + + async def _check() -> None: + eng = create_async_engine(admin_postgres_url) + async with eng.connect() as conn: + await conn.execute(text("SELECT 1")) + await eng.dispose() + + _asyncio.run(_check()) + except Exception as exc: + pytest.skip(f"PostgreSQL not available: {exc}") + + # Create test database + async def _create_db() -> None: + eng = create_async_engine(admin_postgres_url, isolation_level="AUTOCOMMIT") + async with eng.connect() as conn: + await conn.execute(text(f'CREATE DATABASE "{db_name}"')) + await eng.dispose() + + asyncio.run(_create_db()) + + # Run Alembic migrations against test DB (subprocess — avoids loop conflicts) + admin_test_url = _replace_db_name(_ADMIN_URL, db_name) + result = subprocess.run( + ["uv", "run", "alembic", "upgrade", "head"], + env={**os.environ, "DATABASE_ADMIN_URL": admin_test_url}, + capture_output=True, + text=True, + cwd=os.path.join(os.path.dirname(__file__), ".."), + ) + if result.returncode != 0: + # Clean up on failure + async def _drop_db() -> None: + eng = create_async_engine(admin_postgres_url, isolation_level="AUTOCOMMIT") + async with eng.connect() as conn: + await conn.execute(text(f'DROP DATABASE IF EXISTS "{db_name}"')) + await eng.dispose() + + asyncio.run(_drop_db()) + pytest.fail(f"Alembic migration failed:\n{result.stdout}\n{result.stderr}") + + yield db_name + + # Teardown: drop test database + async def _cleanup() -> None: + eng = create_async_engine(admin_postgres_url, isolation_level="AUTOCOMMIT") + async with eng.connect() as conn: + await conn.execute( + text( + "SELECT pg_terminate_backend(pid) FROM pg_stat_activity " + "WHERE datname = :dbname AND pid <> pg_backend_pid()" + ), + {"dbname": db_name}, + ) + await conn.execute(text(f'DROP DATABASE IF EXISTS "{db_name}"')) + await eng.dispose() + + asyncio.run(_cleanup()) + + +@pytest_asyncio.fixture +async def db_engine(test_db_name: str) -> AsyncGenerator[AsyncEngine, None]: + """ + Function-scoped async engine connected as konstruct_app. + + Using konstruct_app role is critical — it enforces RLS. The postgres + superuser would bypass RLS and make isolation tests worthless. + """ + app_test_url = _replace_db_name(_APP_URL, test_db_name) + engine = create_async_engine(app_test_url, echo=False) + yield engine + await engine.dispose() + + +@pytest_asyncio.fixture +async def db_session(db_engine: AsyncEngine) -> AsyncGenerator[AsyncSession, None]: + """ + Function-scoped async session connected as konstruct_app. + + The RLS hook is configured on this engine so SET LOCAL statements are + injected before each query when current_tenant_id is set. + """ + from shared.rls import configure_rls_hook + + # Always configure — SQLAlchemy event.listens_for is idempotent per listener function + # when the same function object is registered; but since configure_rls_hook creates + # a new closure each call, wrap with a set to avoid duplicate listeners. + configure_rls_hook(db_engine) + + session_factory = async_sessionmaker(db_engine, class_=AsyncSession, expire_on_commit=False) + async with session_factory() as session: + yield session + await session.rollback() + + +@pytest_asyncio.fixture +async def tenant_a(db_session: AsyncSession) -> dict[str, Any]: + """Create Tenant A and return its data dict.""" + tenant_id = uuid.uuid4() + suffix = uuid.uuid4().hex[:6] + await db_session.execute( + text("INSERT INTO tenants (id, name, slug, settings) VALUES (:id, :name, :slug, :settings)"), + { + "id": str(tenant_id), + "name": f"Tenant Alpha {suffix}", + "slug": f"tenant-alpha-{suffix}", + "settings": "{}", + }, + ) + await db_session.commit() + return {"id": tenant_id} + + +@pytest_asyncio.fixture +async def tenant_b(db_session: AsyncSession) -> dict[str, Any]: + """Create Tenant B and return its data dict.""" + tenant_id = uuid.uuid4() + suffix = uuid.uuid4().hex[:6] + await db_session.execute( + text("INSERT INTO tenants (id, name, slug, settings) VALUES (:id, :name, :slug, :settings)"), + { + "id": str(tenant_id), + "name": f"Tenant Beta {suffix}", + "slug": f"tenant-beta-{suffix}", + "settings": "{}", + }, + ) + await db_session.commit() + return {"id": tenant_id} diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/integration/test_tenant_isolation.py b/tests/integration/test_tenant_isolation.py new file mode 100644 index 0000000..40590e2 --- /dev/null +++ b/tests/integration/test_tenant_isolation.py @@ -0,0 +1,273 @@ +""" +Integration tests for PostgreSQL Row Level Security (RLS) tenant isolation. + +Tests TNNT-01: All tenant data is isolated via PostgreSQL Row Level Security. + +THIS IS THE MOST CRITICAL TEST IN PHASE 1. + +These tests prove that Tenant A cannot see Tenant B's data through PostgreSQL +RLS — not through application-layer filtering, but at the database level. + +Critical verification points: +1. tenant_b cannot see tenant_a's agents (even by primary key lookup) +2. tenant_a can see its own agents +3. FORCE ROW LEVEL SECURITY is active (relforcerowsecurity = TRUE) +4. Same isolation holds for channel_connections table +5. All tests connect as konstruct_app (not superuser) +""" + +from __future__ import annotations + +import uuid +from typing import Any + +import pytest +from sqlalchemy import text +from sqlalchemy.ext.asyncio import AsyncSession + +from shared.rls import current_tenant_id + + +@pytest.mark.asyncio +class TestAgentRLSIsolation: + """Prove that Agent rows are invisible across tenant boundaries.""" + + async def test_tenant_b_cannot_see_tenant_a_agent( + self, + db_session: AsyncSession, + tenant_a: dict[str, Any], + tenant_b: dict[str, Any], + ) -> None: + """ + Core RLS test: tenant_b must see ZERO rows from agents owned by tenant_a. + + If this test passes trivially (e.g., because the connection is superuser), + the test_rls_is_forced test below will catch it. + """ + agent_id = uuid.uuid4() + + # Create an agent for tenant_a (using superuser-like direct insert) + # We need to temporarily bypass RLS to seed data + token = current_tenant_id.set(tenant_a["id"]) + try: + await db_session.execute( + text( + "INSERT INTO agents (id, tenant_id, name, role) " + "VALUES (:id, :tenant_id, :name, :role)" + ), + { + "id": str(agent_id), + "tenant_id": str(tenant_a["id"]), + "name": "Alice", + "role": "Support Lead", + }, + ) + await db_session.commit() + finally: + current_tenant_id.reset(token) + + # Now set tenant_b context and try to query the agent + token = current_tenant_id.set(tenant_b["id"]) + try: + result = await db_session.execute(text("SELECT id FROM agents")) + rows = result.fetchall() + assert len(rows) == 0, ( + f"RLS FAILURE: tenant_b can see {len(rows)} agent(s) belonging to tenant_a. " + "This is a critical security violation — check that FORCE ROW LEVEL SECURITY " + "is applied and that the connection uses konstruct_app role (not superuser)." + ) + finally: + current_tenant_id.reset(token) + + async def test_tenant_a_can_see_own_agents( + self, + db_session: AsyncSession, + tenant_a: dict[str, Any], + tenant_b: dict[str, Any], + ) -> None: + """Tenant A must be able to see its own agents — RLS must not block legitimate access.""" + agent_id = uuid.uuid4() + + # Insert agent as tenant_a + token = current_tenant_id.set(tenant_a["id"]) + try: + await db_session.execute( + text( + "INSERT INTO agents (id, tenant_id, name, role) " + "VALUES (:id, :tenant_id, :name, :role)" + ), + { + "id": str(agent_id), + "tenant_id": str(tenant_a["id"]), + "name": "Bob", + "role": "Sales Rep", + }, + ) + await db_session.commit() + finally: + current_tenant_id.reset(token) + + # Query as tenant_a — must see the agent + token = current_tenant_id.set(tenant_a["id"]) + try: + result = await db_session.execute(text("SELECT id FROM agents WHERE id = :id"), {"id": str(agent_id)}) + rows = result.fetchall() + assert len(rows) == 1, ( + f"Tenant A cannot see its own agent. Found {len(rows)} rows. " + "RLS policy may be too restrictive." + ) + finally: + current_tenant_id.reset(token) + + +@pytest.mark.asyncio +class TestChannelConnectionRLSIsolation: + """Prove that ChannelConnection rows are invisible across tenant boundaries.""" + + async def test_tenant_b_cannot_see_tenant_a_channel_connection( + self, + db_session: AsyncSession, + tenant_a: dict[str, Any], + tenant_b: dict[str, Any], + ) -> None: + """tenant_b must see ZERO channel_connections owned by tenant_a.""" + conn_id = uuid.uuid4() + + # Create a channel connection for tenant_a + token = current_tenant_id.set(tenant_a["id"]) + try: + await db_session.execute( + text( + "INSERT INTO channel_connections (id, tenant_id, channel_type, workspace_id, config) " + "VALUES (:id, :tenant_id, 'slack', :workspace_id, :config)" + ), + { + "id": str(conn_id), + "tenant_id": str(tenant_a["id"]), + "workspace_id": "T-ALPHA-WORKSPACE", + "config": "{}", + }, + ) + await db_session.commit() + finally: + current_tenant_id.reset(token) + + # Query as tenant_b — must see zero rows + token = current_tenant_id.set(tenant_b["id"]) + try: + result = await db_session.execute(text("SELECT id FROM channel_connections")) + rows = result.fetchall() + assert len(rows) == 0, ( + f"RLS FAILURE: tenant_b can see {len(rows)} channel_connection(s) belonging to tenant_a." + ) + finally: + current_tenant_id.reset(token) + + async def test_tenant_a_can_see_own_channel_connections( + self, + db_session: AsyncSession, + tenant_a: dict[str, Any], + ) -> None: + """Tenant A must see its own channel connections.""" + conn_id = uuid.uuid4() + + token = current_tenant_id.set(tenant_a["id"]) + try: + await db_session.execute( + text( + "INSERT INTO channel_connections (id, tenant_id, channel_type, workspace_id, config) " + "VALUES (:id, :tenant_id, 'telegram', :workspace_id, :config)" + ), + { + "id": str(conn_id), + "tenant_id": str(tenant_a["id"]), + "workspace_id": "tg-alpha-chat", + "config": "{}", + }, + ) + await db_session.commit() + finally: + current_tenant_id.reset(token) + + # Query as tenant_a — must see the connection + token = current_tenant_id.set(tenant_a["id"]) + try: + result = await db_session.execute( + text("SELECT id FROM channel_connections WHERE id = :id"), + {"id": str(conn_id)}, + ) + rows = result.fetchall() + assert len(rows) == 1, f"Tenant A cannot see its own channel connection. Found {len(rows)} rows." + finally: + current_tenant_id.reset(token) + + +@pytest.mark.asyncio +class TestRLSPolicyConfiguration: + """Verify PostgreSQL RLS configuration is correct at the schema level.""" + + async def test_agents_force_row_level_security_is_active( + self, + db_session: AsyncSession, + tenant_a: dict[str, Any], + ) -> None: + """ + FORCE ROW LEVEL SECURITY must be TRUE for agents table. + + Without FORCE RLS, the table owner (postgres) bypasses RLS. + This would mean our isolation tests passed trivially and provide + zero real protection. + """ + # We need to query pg_class as superuser to check this + # Using the session (which is konstruct_app) — pg_class is readable by all + token = current_tenant_id.set(tenant_a["id"]) + try: + result = await db_session.execute( + text("SELECT relforcerowsecurity FROM pg_class WHERE relname = 'agents'") + ) + row = result.fetchone() + finally: + current_tenant_id.reset(token) + + assert row is not None, "agents table not found in pg_class" + assert row[0] is True, ( + "FORCE ROW LEVEL SECURITY is NOT active on agents table. " + "This is a critical security misconfiguration — the table owner " + "can bypass RLS and cross-tenant data leakage is possible." + ) + + async def test_channel_connections_force_row_level_security_is_active( + self, + db_session: AsyncSession, + tenant_a: dict[str, Any], + ) -> None: + """FORCE ROW LEVEL SECURITY must be TRUE for channel_connections table.""" + token = current_tenant_id.set(tenant_a["id"]) + try: + result = await db_session.execute( + text("SELECT relforcerowsecurity FROM pg_class WHERE relname = 'channel_connections'") + ) + row = result.fetchone() + finally: + current_tenant_id.reset(token) + + assert row is not None, "channel_connections table not found in pg_class" + assert row[0] is True, ( + "FORCE ROW LEVEL SECURITY is NOT active on channel_connections table." + ) + + async def test_tenants_table_exists_and_is_accessible( + self, + db_session: AsyncSession, + tenant_a: dict[str, Any], + ) -> None: + """ + tenants table must be accessible without tenant context. + + RLS is NOT applied to tenants — the Router needs to look up all tenants + during message routing, before tenant context is set. + """ + result = await db_session.execute(text("SELECT id, slug FROM tenants LIMIT 10")) + rows = result.fetchall() + # Should be accessible (no RLS) — we just care it doesn't raise + assert rows is not None diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unit/test_normalize.py b/tests/unit/test_normalize.py new file mode 100644 index 0000000..ec66df9 --- /dev/null +++ b/tests/unit/test_normalize.py @@ -0,0 +1,143 @@ +""" +Unit tests for KonstructMessage normalization from Slack event payloads. + +Tests CHAN-01: Channel Gateway normalizes messages from all channels into +unified KonstructMessage format. + +These tests exercise normalization logic without requiring a live database. +""" + +from __future__ import annotations + +from datetime import datetime, timezone + +import pytest + +from shared.models.message import ChannelType, KonstructMessage, MessageContent, SenderInfo + + +def make_slack_event() -> dict: + """Minimal valid Slack message event payload.""" + return { + "type": "message", + "channel": "C12345ABC", + "user": "U98765XYZ", + "text": "Hey @bot can you help me?", + "ts": "1711234567.123456", + "thread_ts": "1711234567.123456", + "team": "T11223344", + "blocks": [], + "event_ts": "1711234567.123456", + } + + +def normalize_slack_event(payload: dict) -> KonstructMessage: + """ + Minimal normalizer for test purposes. + + The real normalizer lives in packages/gateway/channels/slack.py. + This function mirrors the expected output for unit testing the model. + """ + ts = float(payload["ts"]) + timestamp = datetime.fromtimestamp(ts, tz=timezone.utc) + + return KonstructMessage( + channel=ChannelType.SLACK, + channel_metadata={ + "workspace_id": payload["team"], + "channel_id": payload["channel"], + "event_ts": payload["event_ts"], + }, + sender=SenderInfo( + user_id=payload["user"], + display_name=payload["user"], # Display name resolved later by Slack API + ), + content=MessageContent( + text=payload["text"], + mentions=[u for u in payload["text"].split() if u.startswith("@")], + ), + timestamp=timestamp, + thread_id=payload.get("thread_ts"), + reply_to=None if payload.get("thread_ts") == payload.get("ts") else payload.get("thread_ts"), + ) + + +class TestKonstructMessageNormalization: + """Tests for Slack event normalization to KonstructMessage.""" + + def test_channel_type_is_slack(self) -> None: + """ChannelType must be set to 'slack' for Slack events.""" + msg = normalize_slack_event(make_slack_event()) + assert msg.channel == ChannelType.SLACK + assert msg.channel == "slack" + + def test_sender_info_extracted(self) -> None: + """SenderInfo user_id must match Slack user field.""" + payload = make_slack_event() + msg = normalize_slack_event(payload) + assert msg.sender.user_id == "U98765XYZ" + assert msg.sender.is_bot is False + + def test_content_text_preserved(self) -> None: + """MessageContent.text must contain original Slack message text.""" + payload = make_slack_event() + msg = normalize_slack_event(payload) + assert msg.content.text == "Hey @bot can you help me?" + + def test_thread_id_from_thread_ts(self) -> None: + """thread_id must be populated from Slack's thread_ts field.""" + payload = make_slack_event() + payload["thread_ts"] = "1711234500.000001" # Different from ts — it's a reply + msg = normalize_slack_event(payload) + assert msg.thread_id == "1711234500.000001" + + def test_thread_id_none_when_no_thread(self) -> None: + """thread_id must be None when the message is not in a thread.""" + payload = make_slack_event() + del payload["thread_ts"] + msg = normalize_slack_event(payload) + assert msg.thread_id is None + + def test_channel_metadata_contains_workspace_id(self) -> None: + """channel_metadata must contain workspace_id (Slack team ID).""" + payload = make_slack_event() + msg = normalize_slack_event(payload) + assert "workspace_id" in msg.channel_metadata + assert msg.channel_metadata["workspace_id"] == "T11223344" + + def test_channel_metadata_contains_channel_id(self) -> None: + """channel_metadata must contain the Slack channel ID.""" + payload = make_slack_event() + msg = normalize_slack_event(payload) + assert msg.channel_metadata["channel_id"] == "C12345ABC" + + def test_tenant_id_is_none_before_resolution(self) -> None: + """tenant_id must be None immediately after normalization (Router populates it).""" + msg = normalize_slack_event(make_slack_event()) + assert msg.tenant_id is None + + def test_message_has_uuid_id(self) -> None: + """KonstructMessage must have a UUID id assigned at construction.""" + import uuid + + msg = normalize_slack_event(make_slack_event()) + # Should not raise + parsed = uuid.UUID(msg.id) + assert str(parsed) == msg.id + + def test_timestamp_is_utc_datetime(self) -> None: + """timestamp must be a timezone-aware datetime in UTC.""" + msg = normalize_slack_event(make_slack_event()) + assert msg.timestamp.tzinfo is not None + assert msg.timestamp.tzinfo == timezone.utc + + def test_pydantic_validation_rejects_invalid_channel(self) -> None: + """KonstructMessage must reject unknown ChannelType values.""" + with pytest.raises(Exception): # pydantic.ValidationError + KonstructMessage( + channel="fax_machine", # type: ignore[arg-type] + channel_metadata={}, + sender=SenderInfo(user_id="u1", display_name="User"), + content=MessageContent(text="hello"), + timestamp=datetime.now(timezone.utc), + ) diff --git a/tests/unit/test_redis_namespacing.py b/tests/unit/test_redis_namespacing.py new file mode 100644 index 0000000..c7f5658 --- /dev/null +++ b/tests/unit/test_redis_namespacing.py @@ -0,0 +1,140 @@ +""" +Unit tests for Redis key namespacing. + +Tests TNNT-03: Per-tenant Redis namespace isolation for cache and session state. + +Every key function must: +1. Prepend {tenant_id}: to every key +2. Never be callable without a tenant_id argument +3. Produce predictable, stable key formats +""" + +from __future__ import annotations + +import inspect + +import pytest + +from shared.redis_keys import ( + engaged_thread_key, + idempotency_key, + rate_limit_key, + session_key, +) + + +class TestRedisKeyFormats: + """Tests for correct key format from all key constructor functions.""" + + def test_rate_limit_key_format(self) -> None: + """rate_limit_key must return '{tenant_id}:ratelimit:{channel}'.""" + key = rate_limit_key("tenant-a", "slack") + assert key == "tenant-a:ratelimit:slack" + + def test_rate_limit_key_different_channel(self) -> None: + """rate_limit_key must work for any channel type string.""" + key = rate_limit_key("tenant-a", "whatsapp") + assert key == "tenant-a:ratelimit:whatsapp" + + def test_idempotency_key_format(self) -> None: + """idempotency_key must return '{tenant_id}:dedup:{message_id}'.""" + key = idempotency_key("tenant-a", "msg-123") + assert key == "tenant-a:dedup:msg-123" + + def test_session_key_format(self) -> None: + """session_key must return '{tenant_id}:session:{thread_id}'.""" + key = session_key("tenant-a", "thread-456") + assert key == "tenant-a:session:thread-456" + + def test_engaged_thread_key_format(self) -> None: + """engaged_thread_key must return '{tenant_id}:engaged:{thread_id}'.""" + key = engaged_thread_key("tenant-a", "T12345") + assert key == "tenant-a:engaged:T12345" + + +class TestTenantIsolation: + """Tests that all key functions produce distinct namespaces per tenant.""" + + def test_rate_limit_keys_are_tenant_scoped(self) -> None: + """Two tenants with the same channel must produce different keys.""" + key_a = rate_limit_key("tenant-a", "slack") + key_b = rate_limit_key("tenant-b", "slack") + assert key_a != key_b + assert key_a.startswith("tenant-a:") + assert key_b.startswith("tenant-b:") + + def test_idempotency_keys_are_tenant_scoped(self) -> None: + """Two tenants with the same message_id must produce different keys.""" + key_a = idempotency_key("tenant-a", "msg-999") + key_b = idempotency_key("tenant-b", "msg-999") + assert key_a != key_b + assert key_a.startswith("tenant-a:") + assert key_b.startswith("tenant-b:") + + def test_session_keys_are_tenant_scoped(self) -> None: + """Two tenants with the same thread_id must produce different keys.""" + key_a = session_key("tenant-a", "thread-1") + key_b = session_key("tenant-b", "thread-1") + assert key_a != key_b + + def test_engaged_thread_keys_are_tenant_scoped(self) -> None: + """Two tenants with same thread must produce different engaged keys.""" + key_a = engaged_thread_key("tenant-a", "thread-1") + key_b = engaged_thread_key("tenant-b", "thread-1") + assert key_a != key_b + + def test_all_keys_include_tenant_id_prefix(self) -> None: + """Every key function must produce a key starting with the tenant_id.""" + tenant_id = "my-tenant-uuid" + keys = [ + rate_limit_key(tenant_id, "slack"), + idempotency_key(tenant_id, "msg-1"), + session_key(tenant_id, "thread-1"), + engaged_thread_key(tenant_id, "thread-1"), + ] + for key in keys: + assert key.startswith(f"{tenant_id}:"), ( + f"Key {key!r} does not start with tenant_id prefix '{tenant_id}:'" + ) + + +class TestNoBarKeysIsPossible: + """Tests that prove no key function can be called without tenant_id.""" + + def test_rate_limit_key_requires_tenant_id(self) -> None: + """rate_limit_key signature requires tenant_id as first argument.""" + sig = inspect.signature(rate_limit_key) + params = list(sig.parameters.keys()) + assert params[0] == "tenant_id" + + def test_idempotency_key_requires_tenant_id(self) -> None: + """idempotency_key signature requires tenant_id as first argument.""" + sig = inspect.signature(idempotency_key) + params = list(sig.parameters.keys()) + assert params[0] == "tenant_id" + + def test_session_key_requires_tenant_id(self) -> None: + """session_key signature requires tenant_id as first argument.""" + sig = inspect.signature(session_key) + params = list(sig.parameters.keys()) + assert params[0] == "tenant_id" + + def test_engaged_thread_key_requires_tenant_id(self) -> None: + """engaged_thread_key signature requires tenant_id as first argument.""" + sig = inspect.signature(engaged_thread_key) + params = list(sig.parameters.keys()) + assert params[0] == "tenant_id" + + def test_calling_without_tenant_id_raises_type_error(self) -> None: + """Calling any key function with zero args must raise TypeError.""" + with pytest.raises(TypeError): + rate_limit_key() # type: ignore[call-arg] + + with pytest.raises(TypeError): + idempotency_key() # type: ignore[call-arg] + + with pytest.raises(TypeError): + session_key() # type: ignore[call-arg] + + with pytest.raises(TypeError): + engaged_thread_key() # type: ignore[call-arg] diff --git a/tests/unit/test_tenant_resolution.py b/tests/unit/test_tenant_resolution.py new file mode 100644 index 0000000..e209e0d --- /dev/null +++ b/tests/unit/test_tenant_resolution.py @@ -0,0 +1,158 @@ +""" +Unit tests for tenant resolution logic. + +Tests TNNT-02: Inbound messages are resolved to the correct tenant via +channel metadata. + +These tests verify the resolution logic in isolation — no live database needed. +The production resolver queries channel_connections; here we mock that lookup. +""" + +from __future__ import annotations + +import uuid +from typing import Optional + +import pytest + +from shared.models.message import ChannelType + + +# --------------------------------------------------------------------------- +# Minimal in-process tenant resolver for unit testing +# --------------------------------------------------------------------------- + +class ChannelConnectionRecord: + """Represents a row from the channel_connections table.""" + + def __init__(self, tenant_id: uuid.UUID, channel_type: ChannelType, workspace_id: str) -> None: + self.tenant_id = tenant_id + self.channel_type = channel_type + self.workspace_id = workspace_id + + +def resolve_tenant( + workspace_id: str, + channel_type: ChannelType, + connections: list[ChannelConnectionRecord], +) -> Optional[uuid.UUID]: + """ + Resolve a (workspace_id, channel_type) pair to a tenant_id. + + This mirrors the logic in packages/router/tenant.py. + Returns None if no matching connection is found. + """ + for conn in connections: + if conn.workspace_id == workspace_id and conn.channel_type == channel_type: + return conn.tenant_id + return None + + +# --------------------------------------------------------------------------- +# Test fixtures +# --------------------------------------------------------------------------- + +@pytest.fixture +def tenant_a_id() -> uuid.UUID: + return uuid.UUID("aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa") + + +@pytest.fixture +def tenant_b_id() -> uuid.UUID: + return uuid.UUID("bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb") + + +@pytest.fixture +def connections(tenant_a_id: uuid.UUID, tenant_b_id: uuid.UUID) -> list[ChannelConnectionRecord]: + return [ + ChannelConnectionRecord(tenant_a_id, ChannelType.SLACK, "T-WORKSPACE-A"), + ChannelConnectionRecord(tenant_b_id, ChannelType.SLACK, "T-WORKSPACE-B"), + ChannelConnectionRecord(tenant_b_id, ChannelType.TELEGRAM, "tg-chat-12345"), + ] + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + +class TestTenantResolution: + """Tests for tenant resolution from channel workspace IDs.""" + + def test_slack_workspace_resolves_to_correct_tenant( + self, + connections: list[ChannelConnectionRecord], + tenant_a_id: uuid.UUID, + ) -> None: + """Known Slack workspace_id must resolve to the correct tenant.""" + result = resolve_tenant("T-WORKSPACE-A", ChannelType.SLACK, connections) + assert result == tenant_a_id + + def test_second_slack_workspace_resolves_independently( + self, + connections: list[ChannelConnectionRecord], + tenant_b_id: uuid.UUID, + ) -> None: + """Two different Slack workspaces must resolve to their respective tenants.""" + result = resolve_tenant("T-WORKSPACE-B", ChannelType.SLACK, connections) + assert result == tenant_b_id + + def test_unknown_workspace_id_returns_none( + self, + connections: list[ChannelConnectionRecord], + ) -> None: + """Unknown workspace_id must return None — not raise, not return wrong tenant.""" + result = resolve_tenant("T-UNKNOWN", ChannelType.SLACK, connections) + assert result is None + + def test_wrong_channel_type_does_not_match( + self, + connections: list[ChannelConnectionRecord], + ) -> None: + """Workspace ID from wrong channel type must not match.""" + # T-WORKSPACE-A is registered as SLACK — should not match TELEGRAM + result = resolve_tenant("T-WORKSPACE-A", ChannelType.TELEGRAM, connections) + assert result is None + + def test_telegram_workspace_resolves_correctly( + self, + connections: list[ChannelConnectionRecord], + tenant_b_id: uuid.UUID, + ) -> None: + """Telegram channel connections resolve independently from Slack.""" + result = resolve_tenant("tg-chat-12345", ChannelType.TELEGRAM, connections) + assert result == tenant_b_id + + def test_empty_connections_returns_none(self) -> None: + """Empty connection list must return None for any workspace.""" + result = resolve_tenant("T-ANY", ChannelType.SLACK, []) + assert result is None + + def test_resolution_is_channel_type_specific( + self, + connections: list[ChannelConnectionRecord], + ) -> None: + """ + The same workspace_id string registered on two different channel types + must only match the correct channel type. + + This prevents a Slack workspace ID from accidentally matching a + Mattermost workspace with the same string value. + """ + same_id_connections = [ + ChannelConnectionRecord( + uuid.UUID("cccccccc-cccc-cccc-cccc-cccccccccccc"), + ChannelType.SLACK, + "SHARED-ID", + ), + ChannelConnectionRecord( + uuid.UUID("dddddddd-dddd-dddd-dddd-dddddddddddd"), + ChannelType.MATTERMOST, + "SHARED-ID", + ), + ] + slack_tenant = resolve_tenant("SHARED-ID", ChannelType.SLACK, same_id_connections) + mm_tenant = resolve_tenant("SHARED-ID", ChannelType.MATTERMOST, same_id_connections) + + assert slack_tenant != mm_tenant + assert str(slack_tenant) == "cccccccc-cccc-cccc-cccc-cccccccccccc" + assert str(mm_tenant) == "dddddddd-dddd-dddd-dddd-dddddddddddd"