diff --git a/tests/integration/test_agent_persona.py b/tests/integration/test_agent_persona.py new file mode 100644 index 0000000..732c0ca --- /dev/null +++ b/tests/integration/test_agent_persona.py @@ -0,0 +1,307 @@ +""" +Integration tests for agent persona reflection in LLM system prompts (AGNT-01). + +Tests verify: + 1. The system prompt contains the agent's name, role, and persona + 2. The AI transparency clause is always present + 3. model_preference from the agent config is passed to the LLM pool + 4. The full message array (system + user) is correctly structured + +These tests mock the LLM pool HTTP call — no real LLM API keys required. +They test the orchestrator -> agent builder -> runner chain in isolation. +""" + +from __future__ import annotations + +import uuid +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from orchestrator.agents.builder import build_messages, build_system_prompt + + +class _MockAgent: + """Minimal mock of the Agent ORM model for unit testing the builder.""" + + def __init__( + self, + name: str, + role: str, + persona: str, + system_prompt: str = "", + model_preference: str = "quality", + ) -> None: + self.id = uuid.uuid4() + self.tenant_id = uuid.uuid4() + self.name = name + self.role = role + self.persona = persona + self.system_prompt = system_prompt + self.model_preference = model_preference + self.is_active = True + + +class TestAgentPersonaInSystemPrompt: + """AGNT-01: Agent identity and persona must appear in the system prompt.""" + + def test_agent_name_in_system_prompt(self) -> None: + """System prompt must contain 'Your name is {agent.name}'.""" + agent = _MockAgent(name="Mara", role="Customer Support", persona="Professional and empathetic") + prompt = build_system_prompt(agent) + assert "Mara" in prompt + assert "Your name is Mara" in prompt + + def test_agent_role_in_system_prompt(self) -> None: + """System prompt must contain the agent's role.""" + agent = _MockAgent(name="Mara", 
role="Customer Support", persona="Professional and empathetic") + prompt = build_system_prompt(agent) + assert "Customer Support" in prompt + assert "Your role is Customer Support" in prompt + + def test_agent_persona_in_system_prompt(self) -> None: + """System prompt must include the agent's persona text.""" + agent = _MockAgent( + name="Mara", + role="Customer Support", + persona="Professional and empathetic", + ) + prompt = build_system_prompt(agent) + assert "Professional and empathetic" in prompt + + def test_ai_transparency_clause_always_present(self) -> None: + """ + The AI transparency clause must be present in every system prompt, + regardless of agent configuration. + + Agents must acknowledge they are AIs when directly asked. + """ + agent = _MockAgent(name="Mara", role="Support", persona="") + prompt = build_system_prompt(agent) + # The clause uses the word "AI" — verify it's unconditionally injected + assert "AI" in prompt or "artificial intelligence" in prompt.lower() + # Verify the specific phrase from builder.py + assert "you are an AI" in prompt.lower() or "you are an ai" in prompt.lower() + + def test_ai_transparency_present_even_with_empty_persona(self) -> None: + """Transparency clause must appear even when persona is empty.""" + agent = _MockAgent(name="Bot", role="Assistant", persona="") + prompt = build_system_prompt(agent) + assert "AI" in prompt + + def test_custom_system_prompt_included(self) -> None: + """If agent has a base system_prompt, it must appear in the output.""" + agent = _MockAgent( + name="Mara", + role="Support", + persona="Helpful", + system_prompt="Always be concise.", + ) + prompt = build_system_prompt(agent) + assert "Always be concise." in prompt + + def test_full_persona_customer_support_scenario(self) -> None: + """ + Full system prompt for a 'Mara' customer support agent must contain + all required elements. 
+ """ + agent = _MockAgent( + name="Mara", + role="Customer Support", + persona="Professional, empathetic, solution-oriented.", + ) + prompt = build_system_prompt(agent) + + assert "Mara" in prompt + assert "Customer Support" in prompt + assert "Professional, empathetic, solution-oriented." in prompt + assert "AI" in prompt # Transparency clause + + def test_name_and_role_on_same_line(self) -> None: + """Name and role must appear together in the identity sentence.""" + agent = _MockAgent(name="Atlas", role="DevOps Engineer", persona="") + prompt = build_system_prompt(agent) + assert "Your name is Atlas. Your role is DevOps Engineer." in prompt + + +class TestAgentPersonaInMessages: + """Verify the full messages array structure passed to the LLM pool.""" + + def test_messages_has_system_message_first(self) -> None: + """The first message must be the system message.""" + agent = _MockAgent(name="Mara", role="Support", persona="Helpful") + prompt = build_system_prompt(agent) + messages = build_messages(system_prompt=prompt, user_message="Hello") + assert messages[0]["role"] == "system" + assert messages[0]["content"] == prompt + + def test_messages_has_user_message_last(self) -> None: + """The last message must be the user message.""" + agent = _MockAgent(name="Mara", role="Support", persona="Helpful") + prompt = build_system_prompt(agent) + user_text = "Can you help with my order?" 
+ messages = build_messages(system_prompt=prompt, user_message=user_text) + assert messages[-1]["role"] == "user" + assert messages[-1]["content"] == user_text + + def test_messages_has_exactly_two_entries_no_history(self) -> None: + """Without history, messages must have exactly [system, user].""" + agent = _MockAgent(name="Mara", role="Support", persona="Helpful") + prompt = build_system_prompt(agent) + messages = build_messages(system_prompt=prompt, user_message="Hi") + assert len(messages) == 2 + + def test_messages_includes_history_in_order(self) -> None: + """Conversation history must appear between system and user messages.""" + agent = _MockAgent(name="Mara", role="Support", persona="Helpful") + prompt = build_system_prompt(agent) + history = [ + {"role": "user", "content": "Previous question"}, + {"role": "assistant", "content": "Previous answer"}, + ] + messages = build_messages(system_prompt=prompt, user_message="Follow-up", history=history) + # Structure: system, history[0], history[1], user + assert len(messages) == 4 + assert messages[1] == history[0] + assert messages[2] == history[1] + assert messages[-1]["role"] == "user" + + +class TestModelPreferencePassthrough: + """Verify model_preference is passed correctly to the LLM pool.""" + + async def test_model_preference_passed_to_llm_pool(self) -> None: + """ + The agent's model_preference must be forwarded as the 'model' field + in the LLM pool /complete request payload. 
+ """ + from orchestrator.agents.runner import run_agent + from shared.models.message import ChannelType, KonstructMessage, MessageContent, SenderInfo + from datetime import datetime, timezone + + agent = _MockAgent( + name="Mara", + role="Customer Support", + persona="Professional and empathetic", + model_preference="quality", + ) + + msg = KonstructMessage( + tenant_id=str(agent.tenant_id), + channel=ChannelType.SLACK, + channel_metadata={"workspace_id": "T-TEST"}, + sender=SenderInfo(user_id="U1", display_name="Test User"), + content=MessageContent(text="Hello Mara"), + timestamp=datetime.now(tz=timezone.utc), + ) + + captured_payloads: list[dict] = [] + + async def mock_post_response(*args, **kwargs): + payload = kwargs.get("json", {}) + captured_payloads.append(payload) + mock_resp = MagicMock() + mock_resp.status_code = 200 + mock_resp.json.return_value = {"content": "Hello from Mara!", "model": "quality"} + return mock_resp + + with patch("httpx.AsyncClient") as mock_http_class: + mock_http_instance = AsyncMock() + mock_http_instance.__aenter__ = AsyncMock(return_value=mock_http_instance) + mock_http_instance.__aexit__ = AsyncMock(return_value=False) + mock_http_instance.post = AsyncMock(side_effect=mock_post_response) + mock_http_class.return_value = mock_http_instance + + result = await run_agent(msg, agent) + + assert len(captured_payloads) == 1 + payload = captured_payloads[0] + assert payload["model"] == "quality" + + async def test_llm_response_returned_as_string(self) -> None: + """run_agent must return the LLM response as a plain string.""" + from orchestrator.agents.runner import run_agent + from shared.models.message import ChannelType, KonstructMessage, MessageContent, SenderInfo + from datetime import datetime, timezone + + agent = _MockAgent( + name="Mara", + role="Support", + persona="Helpful", + model_preference="fast", + ) + + msg = KonstructMessage( + tenant_id=str(agent.tenant_id), + channel=ChannelType.SLACK, + channel_metadata={}, + 
sender=SenderInfo(user_id="U1", display_name="Test"), + content=MessageContent(text="What is 2+2?"), + timestamp=datetime.now(tz=timezone.utc), + ) + + with patch("httpx.AsyncClient") as mock_http_class: + mock_http_instance = AsyncMock() + mock_http_instance.__aenter__ = AsyncMock(return_value=mock_http_instance) + mock_http_instance.__aexit__ = AsyncMock(return_value=False) + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = {"content": "The answer is 4.", "model": "fast"} + mock_http_instance.post = AsyncMock(return_value=mock_response) + mock_http_class.return_value = mock_http_instance + + result = await run_agent(msg, agent) + + assert isinstance(result, str) + assert result == "The answer is 4." + + async def test_system_prompt_forwarded_to_llm_pool(self) -> None: + """ + The system prompt (including persona + AI clause) must be the first + message in the array sent to the LLM pool. + """ + from orchestrator.agents.runner import run_agent + from shared.models.message import ChannelType, KonstructMessage, MessageContent, SenderInfo + from datetime import datetime, timezone + + agent = _MockAgent( + name="Mara", + role="Customer Support", + persona="Professional and empathetic", + ) + + msg = KonstructMessage( + tenant_id=str(agent.tenant_id), + channel=ChannelType.SLACK, + channel_metadata={}, + sender=SenderInfo(user_id="U1", display_name="Test"), + content=MessageContent(text="hi"), + timestamp=datetime.now(tz=timezone.utc), + ) + + captured_messages: list = [] + + async def capture_request(*args, **kwargs): + captured_messages.extend(kwargs.get("json", {}).get("messages", [])) + mock_resp = MagicMock() + mock_resp.status_code = 200 + mock_resp.json.return_value = {"content": "Hello!", "model": "quality"} + return mock_resp + + with patch("httpx.AsyncClient") as mock_http_class: + mock_http_instance = AsyncMock() + mock_http_instance.__aenter__ = AsyncMock(return_value=mock_http_instance) + 
mock_http_instance.__aexit__ = AsyncMock(return_value=False) + mock_http_instance.post = AsyncMock(side_effect=capture_request) + mock_http_class.return_value = mock_http_instance + + await run_agent(msg, agent) + + assert len(captured_messages) >= 2 + system_msg = captured_messages[0] + assert system_msg["role"] == "system" + # System prompt must contain all persona elements + assert "Mara" in system_msg["content"] + assert "Customer Support" in system_msg["content"] + assert "Professional and empathetic" in system_msg["content"] + assert "AI" in system_msg["content"] # Transparency clause diff --git a/tests/integration/test_ratelimit.py b/tests/integration/test_ratelimit.py new file mode 100644 index 0000000..d1fc3c2 --- /dev/null +++ b/tests/integration/test_ratelimit.py @@ -0,0 +1,201 @@ +""" +Integration tests for rate limiting in the Slack event flow (CHAN-05). + +Tests verify: + 1. Over-limit Slack events result in an ephemeral "too many requests" message + being posted via the Slack client + 2. 
Over-limit events do NOT dispatch to Celery + +These tests exercise the full gateway handler code path with: + - fakeredis for rate limit state + - mocked Slack client (no real Slack workspace) + - mocked Celery task (no real Celery broker) +""" + +from __future__ import annotations + +from unittest.mock import AsyncMock, MagicMock, patch + +import fakeredis +import pytest +import pytest_asyncio + +from gateway.channels.slack import _handle_slack_event, check_rate_limit + + +@pytest_asyncio.fixture +async def fake_redis(): + """Fake async Redis for rate limit state.""" + r = fakeredis.aioredis.FakeRedis(decode_responses=True) + yield r + await r.aclose() + + +def _make_slack_event(user_id: str = "U12345", channel: str = "C99999") -> dict: + """Minimal Slack app_mention event payload.""" + return { + "type": "app_mention", + "user": user_id, + "text": "<@UBOT123> hello", + "ts": "1711234567.000100", + "channel": channel, + "channel_type": "channel", + "_workspace_id": "T-WORKSPACE-X", + "_bot_user_id": "UBOT123", + } + + +def _make_mock_session_factory(tenant_id: str = "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"): + """Build a mock session factory that returns the given tenant_id.""" + mock_session = AsyncMock() + mock_session.__aenter__ = AsyncMock(return_value=mock_session) + mock_session.__aexit__ = AsyncMock(return_value=False) + + # resolve_tenant will call session methods — we patch at the function level + mock_factory = MagicMock() + mock_factory.return_value = mock_session + return mock_factory + + +class TestRateLimitIntegration: + """CHAN-05: Integration tests for rate limit behavior in Slack handler.""" + + async def test_over_limit_sends_ephemeral_rejection(self, fake_redis) -> None: + """ + When rate limit is exceeded, an ephemeral 'too many requests' message + must be posted to the user, not dispatched to Celery. 
+ """ + tenant_id = "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa" + mock_client = AsyncMock() + mock_client.chat_postEphemeral = AsyncMock(return_value={"ok": True}) + mock_client.chat_postMessage = AsyncMock(return_value={"ts": "999.000", "ok": True}) + mock_say = AsyncMock() + + # Exhaust the rate limit + for _ in range(30): + await check_rate_limit(tenant_id, "slack", fake_redis, limit=30) + + event = _make_slack_event() + + with ( + patch("gateway.channels.slack.resolve_tenant", new_callable=AsyncMock, return_value=tenant_id), + patch("orchestrator.tasks.handle_message") as mock_celery_task, + ): + await _handle_slack_event( + event=event, + say=mock_say, + client=mock_client, + redis=fake_redis, + get_session=_make_mock_session_factory(tenant_id), + event_type="app_mention", + ) + + # Ephemeral rejection must be sent + mock_client.chat_postEphemeral.assert_called_once() + call_kwargs = mock_client.chat_postEphemeral.call_args + text = call_kwargs.kwargs.get("text", "") + assert "too many requests" in text.lower() or "please try again" in text.lower() + + # Celery task must NOT be dispatched + mock_celery_task.delay.assert_not_called() + + async def test_over_limit_does_not_post_placeholder(self, fake_redis) -> None: + """ + When rate limited, no 'Thinking...' placeholder message should be posted + (the request is rejected before reaching the placeholder step). 
+ """ + tenant_id = "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa" + mock_client = AsyncMock() + mock_client.chat_postEphemeral = AsyncMock(return_value={"ok": True}) + mock_client.chat_postMessage = AsyncMock(return_value={"ts": "999.000", "ok": True}) + mock_say = AsyncMock() + + # Exhaust the rate limit + for _ in range(30): + await check_rate_limit(tenant_id, "slack", fake_redis, limit=30) + + event = _make_slack_event() + + with ( + patch("gateway.channels.slack.resolve_tenant", new_callable=AsyncMock, return_value=tenant_id), + patch("orchestrator.tasks.handle_message"), + ): + await _handle_slack_event( + event=event, + say=mock_say, + client=mock_client, + redis=fake_redis, + get_session=_make_mock_session_factory(tenant_id), + event_type="app_mention", + ) + + # Placeholder message must NOT be posted + mock_client.chat_postMessage.assert_not_called() + + async def test_within_limit_dispatches_to_celery(self, fake_redis) -> None: + """ + Requests within the rate limit must dispatch to Celery (not rejected). 
+ """ + tenant_id = "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa" + mock_client = AsyncMock() + mock_client.chat_postMessage = AsyncMock(return_value={"ts": "999.001", "ok": True}) + mock_say = AsyncMock() + + event = _make_slack_event() + + mock_task = MagicMock() + mock_task.delay = MagicMock() + + with ( + patch("gateway.channels.slack.resolve_tenant", new_callable=AsyncMock, return_value=tenant_id), + patch("router.idempotency.is_duplicate", new_callable=AsyncMock, return_value=False), + patch("gateway.channels.slack._mark_thread_engaged", new_callable=AsyncMock), + patch("gateway.channels.slack.handle_message_task", mock_task, create=True), + ): + # Patch the import inside the function + with patch("orchestrator.tasks.handle_message") as celery_mock: + celery_mock.delay = MagicMock() + await _handle_slack_event( + event=event, + say=mock_say, + client=mock_client, + redis=fake_redis, + get_session=_make_mock_session_factory(tenant_id), + event_type="app_mention", + ) + + # Ephemeral rejection must NOT be sent + mock_client.chat_postEphemeral.assert_not_called() + + async def test_ephemeral_message_includes_retry_hint(self, fake_redis) -> None: + """ + The ephemeral rate limit rejection must mention when to retry. 
+ """ + tenant_id = "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa" + mock_client = AsyncMock() + mock_client.chat_postEphemeral = AsyncMock(return_value={"ok": True}) + mock_say = AsyncMock() + + # Exhaust the rate limit + for _ in range(30): + await check_rate_limit(tenant_id, "slack", fake_redis, limit=30) + + event = _make_slack_event() + + with ( + patch("gateway.channels.slack.resolve_tenant", new_callable=AsyncMock, return_value=tenant_id), + patch("orchestrator.tasks.handle_message"), + ): + await _handle_slack_event( + event=event, + say=mock_say, + client=mock_client, + redis=fake_redis, + get_session=_make_mock_session_factory(tenant_id), + event_type="app_mention", + ) + + call_kwargs = mock_client.chat_postEphemeral.call_args + text = call_kwargs.kwargs.get("text", "") + # Message should give actionable guidance ("try again", "seconds", etc.) + assert any(word in text.lower() for word in ["again", "second", "moment"]) diff --git a/tests/integration/test_slack_flow.py b/tests/integration/test_slack_flow.py new file mode 100644 index 0000000..4195c99 --- /dev/null +++ b/tests/integration/test_slack_flow.py @@ -0,0 +1,462 @@ +""" +Integration tests for the end-to-end Slack event flow (CHAN-02). + +Tests verify: + 1. app_mention event -> normalize -> tenant resolve -> Celery dispatch -> LLM -> thread reply + 2. DM (message with channel_type="im") follows the same pipeline + 3. Placeholder "Thinking..." is posted before Celery dispatch + 4. Placeholder is replaced with real response (via chat.update in orchestrator task) + 5. Bot messages are ignored (no infinite response loop) + 6. Unknown workspace_id events are silently ignored + +All tests mock the Slack client and Celery task — no live Slack workspace or +Celery broker required. 
+""" + +from __future__ import annotations + +from unittest.mock import AsyncMock, MagicMock, patch + +import fakeredis +import pytest +import pytest_asyncio + +from gateway.channels.slack import _handle_slack_event +from gateway.normalize import normalize_slack_event +from shared.models.message import ChannelType + + +@pytest_asyncio.fixture +async def fake_redis(): + """Fake async Redis.""" + r = fakeredis.aioredis.FakeRedis(decode_responses=True) + yield r + await r.aclose() + + +TENANT_ID = "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa" +WORKSPACE_ID = "T-WORKSPACE-TEST" + + +def _make_mention_event( + user: str = "U12345", + text: str = "<@UBOT123> can you help me?", + channel: str = "C99999", +) -> dict: + """Build a Slack app_mention event.""" + return { + "type": "app_mention", + "user": user, + "text": text, + "ts": "1711234567.000100", + "channel": channel, + "channel_type": "channel", + "_workspace_id": WORKSPACE_ID, + "_bot_user_id": "UBOT123", + } + + +def _make_dm_event( + user: str = "U12345", + text: str = "help me please", + channel: str = "D11111", +) -> dict: + """Build a Slack DM (im) event.""" + return { + "type": "message", + "user": user, + "text": text, + "ts": "1711234567.000200", + "channel": channel, + "channel_type": "im", + "_workspace_id": WORKSPACE_ID, + "_bot_user_id": "UBOT123", + } + + +def _make_bot_event() -> dict: + """A Slack bot_message event (must be ignored).""" + return { + "type": "message", + "bot_id": "B11111", + "subtype": "bot_message", + "text": "I replied to something", + "ts": "1711234567.000300", + "channel": "C99999", + "_workspace_id": WORKSPACE_ID, + "_bot_user_id": "UBOT123", + } + + +class TestNormalization: + """normalize_slack_event unit coverage for CHAN-02 gateway code path.""" + + def test_mention_text_strips_bot_token(self) -> None: + """Bot @mention token must be stripped from text before sending to agent.""" + event = {"user": "U1", "text": "<@UBOT123> hello there", "ts": "123.456"} + msg = 
normalize_slack_event(event, workspace_id="T-WS", bot_user_id="UBOT123") + assert msg.content.text == "hello there" + assert "<@UBOT123>" not in msg.content.text + + def test_channel_type_is_slack(self) -> None: + """Normalized message must have channel=SLACK.""" + event = {"user": "U1", "text": "hi", "ts": "123.456"} + msg = normalize_slack_event(event, workspace_id="T-WS") + assert msg.channel == ChannelType.SLACK + + def test_tenant_id_none_after_normalization(self) -> None: + """tenant_id must be None — Router populates it, not the normalizer.""" + event = {"user": "U1", "text": "hi", "ts": "123.456"} + msg = normalize_slack_event(event, workspace_id="T-WS") + assert msg.tenant_id is None + + def test_thread_id_set_from_thread_ts(self) -> None: + """thread_id must be set from thread_ts when present.""" + event = { + "user": "U1", + "text": "reply", + "ts": "123.999", + "thread_ts": "123.000", + } + msg = normalize_slack_event(event, workspace_id="T-WS") + assert msg.thread_id == "123.000" + + def test_workspace_id_in_channel_metadata(self) -> None: + """workspace_id must be stored in channel_metadata.""" + event = {"user": "U1", "text": "hi", "ts": "123.456"} + msg = normalize_slack_event(event, workspace_id="T-WORKSPACE-X") + assert msg.channel_metadata["workspace_id"] == "T-WORKSPACE-X" + + +class TestSlackMentionFlow: + """CHAN-02: End-to-end app_mention event pipeline.""" + + async def test_mention_posts_thinking_placeholder(self, fake_redis) -> None: + """ + When a valid @mention arrives, a 'Thinking...' placeholder must be + posted in-thread before Celery dispatch. 
+ """ + mock_client = AsyncMock() + mock_client.chat_postMessage = AsyncMock(return_value={"ts": "999.001", "ok": True}) + mock_say = AsyncMock() + + event = _make_mention_event() + + with ( + patch("gateway.channels.slack.resolve_tenant", new_callable=AsyncMock, return_value=TENANT_ID), + patch("gateway.channels.slack.is_duplicate", new_callable=AsyncMock, return_value=False), + patch("gateway.channels.slack._mark_thread_engaged", new_callable=AsyncMock), + patch("orchestrator.tasks.handle_message") as celery_mock, + ): + celery_mock.delay = MagicMock() + await _handle_slack_event( + event=event, + say=mock_say, + client=mock_client, + redis=fake_redis, + get_session=_make_mock_session_factory(), + event_type="app_mention", + ) + + # Placeholder must have been posted + mock_client.chat_postMessage.assert_called_once() + call_kwargs = mock_client.chat_postMessage.call_args + placeholder_text = call_kwargs.kwargs.get("text", "") + assert "thinking" in placeholder_text.lower() + + async def test_mention_posts_placeholder_in_thread(self, fake_redis) -> None: + """ + The placeholder must be posted with thread_ts set (in-thread reply). 
+ """ + mock_client = AsyncMock() + mock_client.chat_postMessage = AsyncMock(return_value={"ts": "999.001", "ok": True}) + mock_say = AsyncMock() + + event = _make_mention_event() + event["thread_ts"] = "1711234567.000000" # Set thread context + + with ( + patch("gateway.channels.slack.resolve_tenant", new_callable=AsyncMock, return_value=TENANT_ID), + patch("gateway.channels.slack.is_duplicate", new_callable=AsyncMock, return_value=False), + patch("gateway.channels.slack._mark_thread_engaged", new_callable=AsyncMock), + patch("orchestrator.tasks.handle_message") as celery_mock, + ): + celery_mock.delay = MagicMock() + await _handle_slack_event( + event=event, + say=mock_say, + client=mock_client, + redis=fake_redis, + get_session=_make_mock_session_factory(), + event_type="app_mention", + ) + + call_kwargs = mock_client.chat_postMessage.call_args + # thread_ts must be set in the placeholder post + assert call_kwargs.kwargs.get("thread_ts") is not None + + async def test_mention_dispatches_celery_task(self, fake_redis) -> None: + """ + After the placeholder is posted, the Celery task must be dispatched + with the message data. 
+ """ + mock_client = AsyncMock() + mock_client.chat_postMessage = AsyncMock(return_value={"ts": "999.001", "ok": True}) + mock_say = AsyncMock() + + event = _make_mention_event() + + dispatched_payloads: list[dict] = [] + + def capture_delay(payload: dict) -> MagicMock: + dispatched_payloads.append(payload) + return MagicMock() + + with ( + patch("gateway.channels.slack.resolve_tenant", new_callable=AsyncMock, return_value=TENANT_ID), + patch("gateway.channels.slack.is_duplicate", new_callable=AsyncMock, return_value=False), + patch("gateway.channels.slack._mark_thread_engaged", new_callable=AsyncMock), + patch("orchestrator.tasks.handle_message") as celery_mock, + ): + celery_mock.delay = MagicMock(side_effect=capture_delay) + await _handle_slack_event( + event=event, + say=mock_say, + client=mock_client, + redis=fake_redis, + get_session=_make_mock_session_factory(), + event_type="app_mention", + ) + + assert len(dispatched_payloads) == 1 + payload = dispatched_payloads[0] + assert payload["tenant_id"] == TENANT_ID + assert payload["channel"] == "slack" + # placeholder_ts and channel_id must be present + assert "placeholder_ts" in payload + assert "channel_id" in payload + + async def test_celery_payload_has_placeholder_ts(self, fake_redis) -> None: + """Celery payload must include placeholder_ts for post-LLM chat.update.""" + mock_client = AsyncMock() + placeholder_ts = "1234567890.111111" + mock_client.chat_postMessage = AsyncMock( + return_value={"ts": placeholder_ts, "ok": True} + ) + mock_say = AsyncMock() + + event = _make_mention_event() + dispatched_payloads: list[dict] = [] + + with ( + patch("gateway.channels.slack.resolve_tenant", new_callable=AsyncMock, return_value=TENANT_ID), + patch("gateway.channels.slack.is_duplicate", new_callable=AsyncMock, return_value=False), + patch("gateway.channels.slack._mark_thread_engaged", new_callable=AsyncMock), + patch("orchestrator.tasks.handle_message") as celery_mock, + ): + celery_mock.delay = 
MagicMock(side_effect=lambda p: dispatched_payloads.append(p))
+            await _handle_slack_event(
+                event=event,
+                say=mock_say,
+                client=mock_client,
+                redis=fake_redis,
+                get_session=_make_mock_session_factory(),
+                event_type="app_mention",
+            )
+
+        assert dispatched_payloads[0]["placeholder_ts"] == placeholder_ts
+
+
+class TestDMFlow:
+    """CHAN-02: Direct message event pipeline."""
+
+    async def test_dm_triggers_same_pipeline(self, fake_redis) -> None:
+        """
+        A DM (channel_type='im') must trigger the same handler pipeline
+        as an @mention.
+        """
+        mock_client = AsyncMock()
+        mock_client.chat_postMessage = AsyncMock(return_value={"ts": "999.002", "ok": True})
+        mock_say = AsyncMock()
+
+        event = _make_dm_event()
+
+        with (
+            patch("gateway.channels.slack.resolve_tenant", new_callable=AsyncMock, return_value=TENANT_ID),
+            patch("gateway.channels.slack.is_duplicate", new_callable=AsyncMock, return_value=False),
+            patch("gateway.channels.slack._mark_thread_engaged", new_callable=AsyncMock),
+            patch("orchestrator.tasks.handle_message") as celery_mock,
+        ):
+            celery_mock.delay = MagicMock()
+            await _handle_slack_event(
+                event=event,
+                say=mock_say,
+                client=mock_client,
+                redis=fake_redis,
+                get_session=_make_mock_session_factory(),
+                event_type="dm",
+            )
+
+        # Pipeline should proceed: placeholder posted + Celery dispatched
+        mock_client.chat_postMessage.assert_called_once()
+        celery_mock.delay.assert_called_once()
+
+    async def test_dm_payload_channel_is_slack(self, fake_redis) -> None:
+        """DM normalized message must still have channel=slack."""
+        mock_client = AsyncMock()
+        mock_client.chat_postMessage = AsyncMock(return_value={"ts": "999.002", "ok": True})
+        mock_say = AsyncMock()
+
+        event = _make_dm_event()
+        dispatched_payloads: list[dict] = []
+
+        with (
+            patch("gateway.channels.slack.resolve_tenant", new_callable=AsyncMock, return_value=TENANT_ID),
+            patch("gateway.channels.slack.is_duplicate", new_callable=AsyncMock, return_value=False),
+            patch("gateway.channels.slack._mark_thread_engaged", new_callable=AsyncMock),
+            patch("orchestrator.tasks.handle_message") as celery_mock,
+        ):
+            celery_mock.delay = MagicMock(side_effect=lambda p: dispatched_payloads.append(p))
+            await _handle_slack_event(
+                event=event,
+                say=mock_say,
+                client=mock_client,
+                redis=fake_redis,
+                get_session=_make_mock_session_factory(),
+                event_type="dm",
+            )
+
+        assert dispatched_payloads[0]["channel"] == "slack"
+
+
+class TestBotIgnoring:
+    """Verify bot messages are silently ignored to prevent infinite loops."""
+
+    async def test_bot_message_is_ignored(self, fake_redis) -> None:
+        """
+        Events with bot_id must be silently dropped — no placeholder, no Celery.
+        """
+        mock_client = AsyncMock()
+        mock_say = AsyncMock()
+
+        event = _make_bot_event()
+
+        with patch("orchestrator.tasks.handle_message") as celery_mock:
+            celery_mock.delay = MagicMock()
+            await _handle_slack_event(
+                event=event,
+                say=mock_say,
+                client=mock_client,
+                redis=fake_redis,
+                get_session=_make_mock_session_factory(),
+                event_type="dm",
+            )
+
+        # Nothing should happen
+        mock_client.chat_postMessage.assert_not_called()
+        celery_mock.delay.assert_not_called()
+
+    async def test_bot_message_subtype_is_ignored(self, fake_redis) -> None:
+        """Events with subtype=bot_message must also be ignored."""
+        mock_client = AsyncMock()
+        mock_say = AsyncMock()
+
+        event = {
+            "subtype": "bot_message",
+            "text": "automated message",
+            "ts": "123.456",
+            "channel": "C99999",
+            "_workspace_id": WORKSPACE_ID,
+        }
+
+        with patch("orchestrator.tasks.handle_message") as celery_mock:
+            celery_mock.delay = MagicMock()
+            await _handle_slack_event(
+                event=event,
+                say=mock_say,
+                client=mock_client,
+                redis=fake_redis,
+                get_session=_make_mock_session_factory(),
+                event_type="dm",
+            )
+
+        mock_client.chat_postMessage.assert_not_called()
+        celery_mock.delay.assert_not_called()
+
+
+class TestUnknownWorkspace:
+    """Verify unknown workspace_id events are silently ignored."""
+
+    async def test_unknown_workspace_silently_ignored(self, fake_redis) -> None:
+        """
+        If workspace_id maps to no tenant, the event must be silently dropped —
+        no placeholder, no Celery dispatch, no exception raised.
+        """
+        mock_client = AsyncMock()
+        mock_say = AsyncMock()
+
+        event = _make_mention_event()
+
+        with (
+            patch("gateway.channels.slack.resolve_tenant", new_callable=AsyncMock, return_value=None),
+            patch("orchestrator.tasks.handle_message") as celery_mock,
+        ):
+            celery_mock.delay = MagicMock()
+            # Must not raise
+            await _handle_slack_event(
+                event=event,
+                say=mock_say,
+                client=mock_client,
+                redis=fake_redis,
+                get_session=_make_mock_session_factory(),
+                event_type="app_mention",
+            )
+
+        mock_client.chat_postMessage.assert_not_called()
+        celery_mock.delay.assert_not_called()
+
+
+class TestIdempotency:
+    """Verify duplicate Slack events (retries) are not double-dispatched."""
+
+    async def test_duplicate_event_is_skipped(self, fake_redis) -> None:
+        """
+        If a message was already processed (Slack retry), no placeholder
+        is posted and Celery is not called again.
+        """
+        mock_client = AsyncMock()
+        mock_say = AsyncMock()
+
+        event = _make_mention_event()
+
+        with (
+            patch("gateway.channels.slack.resolve_tenant", new_callable=AsyncMock, return_value=TENANT_ID),
+            patch("gateway.channels.slack.is_duplicate", new_callable=AsyncMock, return_value=True),
+            patch("orchestrator.tasks.handle_message") as celery_mock,
+        ):
+            celery_mock.delay = MagicMock()
+            await _handle_slack_event(
+                event=event,
+                say=mock_say,
+                client=mock_client,
+                redis=fake_redis,
+                get_session=_make_mock_session_factory(),
+                event_type="app_mention",
+            )
+
+        mock_client.chat_postMessage.assert_not_called()
+        celery_mock.delay.assert_not_called()
+
+
+# ---------------------------------------------------------------------------
+# Helper
+# ---------------------------------------------------------------------------
+
+def _make_mock_session_factory(tenant_id: str = TENANT_ID) -> MagicMock:
+    """Return a mock async context manager factory for DB sessions."""
+    mock_session = AsyncMock()
+    mock_session.__aenter__ = AsyncMock(return_value=mock_session)
+    mock_session.__aexit__ = AsyncMock(return_value=False)
+    mock_factory = MagicMock()
+    mock_factory.return_value = mock_session
+    return mock_factory
diff --git a/tests/unit/test_ratelimit.py b/tests/unit/test_ratelimit.py
new file mode 100644
index 0000000..c437f93
--- /dev/null
+++ b/tests/unit/test_ratelimit.py
@@ -0,0 +1,157 @@
+"""
+Unit tests for the Redis token bucket rate limiter.
+
+Tests CHAN-05: Rate limiting enforces per-tenant, per-channel limits.
+
+These tests use fakeredis to run without a live Redis instance.
+"""
+
+from __future__ import annotations
+
+import asyncio
+
+import fakeredis
+import pytest
+import pytest_asyncio
+
+from router.ratelimit import RateLimitExceeded, check_rate_limit
+
+
+@pytest_asyncio.fixture
+async def fake_redis():
+    """Provide a fake async Redis client backed by fakeredis."""
+    r = fakeredis.aioredis.FakeRedis(decode_responses=True)
+    yield r
+    await r.aclose()
+
+
+class TestTokenBucketAllows:
+    """Tests for requests within the rate limit."""
+
+    async def test_single_request_allowed(self, fake_redis) -> None:
+        """A single request is always allowed."""
+        result = await check_rate_limit("tenant_a", "slack", fake_redis, limit=30)
+        assert result is True
+
+    async def test_requests_under_limit_all_allowed(self, fake_redis) -> None:
+        """All 29 requests within a 30-request limit are allowed."""
+        for i in range(29):
+            result = await check_rate_limit("tenant_a", "slack", fake_redis, limit=30)
+            assert result is True, f"Request {i + 1} was unexpectedly rejected"
+
+    async def test_exactly_at_limit_allowed(self, fake_redis) -> None:
+        """The 30th request (exactly at limit) is allowed."""
+        for _ in range(30):
+            result = await check_rate_limit("tenant_a", "slack", fake_redis, limit=30)
+            assert result is True
+
+
+class TestTokenBucketRejects:
+    """Tests for requests exceeding the rate limit."""
+
+    async def test_31st_request_rejected(self, fake_redis) -> None:
+        """The 31st request in a 30-request window is rejected."""
+        # Fill up the bucket
+        for _ in range(30):
+            await check_rate_limit("tenant_a", "slack", fake_redis, limit=30)
+
+        # 31st request must raise
+        with pytest.raises(RateLimitExceeded) as exc_info:
+            await check_rate_limit("tenant_a", "slack", fake_redis, limit=30)
+
+        assert exc_info.value.tenant_id == "tenant_a"
+        assert exc_info.value.channel == "slack"
+
+    async def test_rate_limit_exceeded_has_remaining_seconds(self, fake_redis) -> None:
+        """RateLimitExceeded must expose remaining_seconds attribute."""
+        for _ in range(30):
+            await check_rate_limit("tenant_a", "slack", fake_redis, limit=30, window_seconds=60)
+
+        with pytest.raises(RateLimitExceeded) as exc_info:
+            await check_rate_limit("tenant_a", "slack", fake_redis, limit=30, window_seconds=60)
+
+        assert isinstance(exc_info.value.remaining_seconds, int)
+        assert exc_info.value.remaining_seconds >= 0
+
+    async def test_continued_requests_after_exceeded_also_rejected(self, fake_redis) -> None:
+        """Additional requests after exceeding the limit continue to be rejected."""
+        for _ in range(30):
+            await check_rate_limit("tenant_a", "slack", fake_redis, limit=30)
+
+        for _ in range(5):
+            with pytest.raises(RateLimitExceeded):
+                await check_rate_limit("tenant_a", "slack", fake_redis, limit=30)
+
+
+class TestTokenBucketTenantIsolation:
+    """Tests that rate limit counters are isolated per tenant."""
+
+    async def test_tenant_a_limit_independent_of_tenant_b(self, fake_redis) -> None:
+        """Exhausting tenant A's limit does not affect tenant B."""
+        # Exhaust tenant A's limit
+        for _ in range(30):
+            await check_rate_limit("tenant_a", "slack", fake_redis, limit=30)
+        with pytest.raises(RateLimitExceeded):
+            await check_rate_limit("tenant_a", "slack", fake_redis, limit=30)
+
+        # Tenant B should still be allowed
+        result = await check_rate_limit("tenant_b", "slack", fake_redis, limit=30)
+        assert result is True
+
+    async def test_channel_limits_are_independent(self, fake_redis) -> None:
+        """Exhausting Slack limit does not affect Telegram limit."""
+        for _ in range(30):
+            await check_rate_limit("tenant_a", "slack", fake_redis, limit=30)
+        with pytest.raises(RateLimitExceeded):
+            await check_rate_limit("tenant_a", "slack", fake_redis, limit=30)
+
+        # A different channel should still be allowed
+        result = await check_rate_limit("tenant_a", "telegram", fake_redis, limit=30)
+        assert result is True
+
+    async def test_multiple_tenants_independent_counters(self, fake_redis) -> None:
+        """Multiple tenants maintain separate counter keys."""
+        tenants = ["tenant_a", "tenant_b", "tenant_c"]
+        # Each tenant makes 15 requests
+        for tenant in tenants:
+            for _ in range(15):
+                result = await check_rate_limit(tenant, "slack", fake_redis, limit=30)
+                assert result is True, f"Unexpected rejection for {tenant}"
+
+        # None should be at the limit yet
+        for tenant in tenants:
+            result = await check_rate_limit(tenant, "slack", fake_redis, limit=30)
+            assert result is True
+
+
+class TestTokenBucketWindowReset:
+    """Tests for rate limit window expiry."""
+
+    async def test_limit_resets_after_window_expires(self, fake_redis) -> None:
+        """After the TTL expires, the rate limit resets and requests are allowed again."""
+        # Use a very short window (1 second) to test expiry
+        for _ in range(5):
+            await check_rate_limit("tenant_a", "slack", fake_redis, limit=5, window_seconds=1)
+        with pytest.raises(RateLimitExceeded):
+            await check_rate_limit("tenant_a", "slack", fake_redis, limit=5, window_seconds=1)
+
+        # Manually expire the key to simulate window reset
+        from shared.redis_keys import rate_limit_key
+        key = rate_limit_key("tenant_a", "slack")
+        await fake_redis.delete(key)
+
+        # After key expiry, requests should be allowed again
+        result = await check_rate_limit("tenant_a", "slack", fake_redis, limit=5, window_seconds=1)
+        assert result is True
+
+    async def test_rate_limit_key_has_ttl(self, fake_redis) -> None:
+        """Rate limit key must have a TTL set (window expiry)."""
+        from shared.redis_keys import rate_limit_key
+        key = rate_limit_key("tenant_a", "slack")
+
+        await check_rate_limit("tenant_a", "slack", fake_redis, limit=30, window_seconds=60)
+
+        ttl = await fake_redis.ttl(key)
+        # TTL should be set (positive) — key will expire at end of window
+        assert ttl > 0
+        assert ttl <= 60