From 285ab7af2c959bcbb9fb796f5253a81f8af7e321 Mon Sep 17 00:00:00 2001 From: Patrick Nikoletich Date: Fri, 23 Jan 2026 16:44:42 -0800 Subject: [PATCH] Python steering --- python/copilot/__init__.py | 16 ++ python/copilot/steering.py | 509 +++++++++++++++++++++++++++++++++ python/test_steering.py | 563 +++++++++++++++++++++++++++++++++++++ 3 files changed, 1088 insertions(+) create mode 100644 python/copilot/steering.py create mode 100644 python/test_steering.py diff --git a/python/copilot/__init__.py b/python/copilot/__init__.py index f5961472b..6e8053e49 100644 --- a/python/copilot/__init__.py +++ b/python/copilot/__init__.py @@ -6,6 +6,15 @@ from .client import CopilotClient from .session import CopilotSession +from .steering import ( + ConversationManager, + MessageQueue, + Priority, + QueuedMessage, + QueueFullError, + ShutdownSentinel, + StreamingInputGenerator, +) from .tools import define_tool from .types import ( AzureProviderOptions, @@ -39,6 +48,7 @@ __all__ = [ "AzureProviderOptions", + "ConversationManager", "CopilotClient", "CopilotSession", "ConnectionState", @@ -49,6 +59,7 @@ "MCPRemoteServerConfig", "MCPServerConfig", "MessageOptions", + "MessageQueue", "ModelBilling", "ModelCapabilities", "ModelInfo", @@ -56,11 +67,16 @@ "PermissionHandler", "PermissionRequest", "PermissionRequestResult", + "Priority", "ProviderConfig", + "QueuedMessage", + "QueueFullError", "ResumeSessionConfig", "SessionConfig", "SessionEvent", "SessionMetadata", + "ShutdownSentinel", + "StreamingInputGenerator", "Tool", "ToolHandler", "ToolInvocation", diff --git a/python/copilot/steering.py b/python/copilot/steering.py new file mode 100644 index 000000000..d4576505d --- /dev/null +++ b/python/copilot/steering.py @@ -0,0 +1,509 @@ +""" +Steering - Streaming conversation mode for queuing messages with priority support. 
class Priority(IntEnum):
    """Message priority levels used to order the queue.

    Members with a higher numeric value are dequeued first. Messages that
    share a priority are dequeued in the order they were queued (FIFO).
    """

    LOW = 0
    NORMAL = 1
    HIGH = 2
    URGENT = 3


class QueueFullError(Exception):
    """Raised when the message queue is full and cannot accept new messages."""


class ShutdownSentinel:
    """Marker that tells the consumer to stop iterating.

    It compares greater than every ``QueuedMessage``, so a min-heap based
    priority queue only hands it out after all pending messages.
    """

    def __lt__(self, other: object) -> bool:
        """Never sort ahead of anything."""
        return False

    def __le__(self, other: object) -> bool:
        """Tie only with another sentinel."""
        return isinstance(other, ShutdownSentinel)

    def __gt__(self, other: object) -> bool:
        """Sort after everything except another sentinel."""
        return not isinstance(other, ShutdownSentinel)

    def __ge__(self, other: object) -> bool:
        """Always at or after the other operand."""
        return True


# Shared sentinel instance placed on the queue by MessageQueue.signal_shutdown().
SHUTDOWN_SENTINEL = ShutdownSentinel()


@dataclass(order=False)
class QueuedMessage:
    """A message waiting to be processed.

    Ordering is defined by the explicit comparison methods below (not by the
    dataclass): higher ``priority`` first, then lower ``sequence_number``.

    Attributes:
        request_id: Identifier for this message request.
        content: The message content/prompt text.
        priority: Processing priority (higher = processed first).
        session_id: The session this message belongs to.
        queued_at: When the message object was created.
        metadata: Additional metadata for the message.
        sequence_number: Tie-breaker for FIFO ordering within one priority.
            ``0`` means "not assigned yet"; ``MessageQueue.put`` fills it in.
        attachments: Optional file/directory attachments.
    """

    request_id: str
    content: str
    priority: Priority
    session_id: str
    queued_at: datetime = field(default_factory=datetime.now)
    metadata: dict[str, Any] = field(default_factory=dict)
    sequence_number: int = 0
    attachments: list[Attachment] = field(default_factory=list)

    def __lt__(self, other: object) -> bool:
        """Return True when this message must be dequeued before ``other``."""
        if isinstance(other, ShutdownSentinel):
            return True  # Real messages always come before the sentinel.
        if not isinstance(other, QueuedMessage):
            return NotImplemented
        if self.priority != other.priority:
            # Inverted on purpose: the queue is a min-heap, so the "smallest"
            # element must be the one with the highest priority.
            return self.priority.value > other.priority.value
        return self.sequence_number < other.sequence_number

    def __le__(self, other: object) -> bool:
        if isinstance(other, ShutdownSentinel):
            return True
        if not isinstance(other, QueuedMessage):
            return NotImplemented
        return self < other or (
            self.priority == other.priority
            and self.sequence_number == other.sequence_number
        )

    def __gt__(self, other: object) -> bool:
        if isinstance(other, ShutdownSentinel):
            return False
        if not isinstance(other, QueuedMessage):
            return NotImplemented
        return not self <= other

    def __ge__(self, other: object) -> bool:
        if isinstance(other, ShutdownSentinel):
            return False
        if not isinstance(other, QueuedMessage):
            return NotImplemented
        return not self < other


class MessageQueue:
    """Priority queue for conversation messages.

    ``put()`` never waits (it raises ``QueueFullError`` instead) and ``get()``
    waits until an item is available. The queue is built on
    ``asyncio.PriorityQueue`` and is meant to be used by coroutines running
    on a single event loop; it is not safe to call from other threads.

    Attributes:
        max_depth: Maximum number of messages the queue can hold. A value
            of zero or less means the queue is unbounded.
    """

    def __init__(self, max_depth: int = 100):
        """Initialize the message queue.

        Args:
            max_depth: Maximum queue size. Defaults to 100.
        """
        self._max_depth = max_depth
        # The underlying queue is deliberately unbounded; the capacity limit
        # is enforced in put(). A bounded queue would reject the shutdown
        # sentinel when full, leaving the consumer blocked forever.
        self._queue: asyncio.PriorityQueue[QueuedMessage | ShutdownSentinel] = (
            asyncio.PriorityQueue()
        )
        self._shutdown_event = asyncio.Event()
        self._sequence_counter = 0

    @property
    def max_depth(self) -> int:
        """Maximum queue capacity."""
        return self._max_depth

    def qsize(self) -> int:
        """Return the number of queued items (a queued sentinel counts too)."""
        return self._queue.qsize()

    def empty(self) -> bool:
        """Return True if nothing is queued."""
        return self._queue.empty()

    def full(self) -> bool:
        """Return True if the queue has reached ``max_depth``."""
        return self._max_depth > 0 and self._queue.qsize() >= self._max_depth

    def _next_sequence(self) -> int:
        """Return the next FIFO sequence number.

        Numbering starts at 1 so an assigned number is never confused with
        the ``0`` "unassigned" marker on ``QueuedMessage``.
        """
        self._sequence_counter += 1
        return self._sequence_counter

    async def put(self, message: QueuedMessage) -> None:
        """Add a message to the queue without waiting.

        Args:
            message: The message to queue. Its ``sequence_number`` is
                assigned here when it is still ``0``.

        Raises:
            QueueFullError: If the queue is at max capacity.
        """
        # Check capacity first so a rejected message does not consume a
        # sequence number.
        if self.full():
            raise QueueFullError(f"Queue full (max={self._max_depth})")

        if message.sequence_number == 0:
            message.sequence_number = self._next_sequence()

        self._queue.put_nowait(message)

    async def get(self) -> QueuedMessage | ShutdownSentinel:
        """Wait for and return the next item.

        Returns:
            The highest-priority pending message, or the shutdown sentinel
            once no real messages are ahead of it.
        """
        return await self._queue.get()

    def signal_shutdown(self) -> None:
        """Mark the queue as shut down and enqueue the shutdown sentinel.

        The sentinel sorts after every message, so a consumer drains all
        pending messages before it sees it. Because the underlying queue is
        unbounded the sentinel is always enqueued, even when the queue is
        at ``max_depth``.
        """
        self._shutdown_event.set()
        self._queue.put_nowait(SHUTDOWN_SENTINEL)

    def is_shutdown(self) -> bool:
        """Return True once ``signal_shutdown()`` has been called."""
        return self._shutdown_event.is_set()


class StreamingInputGenerator:
    """Async iterable that turns queued messages into message dicts.

    Iteration pulls items from the queue and yields one formatted dict per
    message until the shutdown sentinel is dequeued.

    Example:
        >>> queue = MessageQueue()
        >>> generator = StreamingInputGenerator(queue)
        >>> async for message in generator:
        ...     # Process message
        ...     pass
    """

    def __init__(self, queue: MessageQueue):
        """Initialize the generator.

        Args:
            queue: The message queue to consume from.
        """
        self._queue = queue

    def __aiter__(self) -> AsyncIterator[dict[str, Any]]:
        """Return a fresh async generator over the queue."""
        return self._generate()

    async def _generate(self) -> AsyncIterator[dict[str, Any]]:
        """Yield formatted messages until the shutdown sentinel is received."""
        while True:
            item = await self._queue.get()

            if isinstance(item, ShutdownSentinel):
                break

            yield {
                "type": "user",
                "message": {
                    "role": "user",
                    "content": item.content,
                },
                "metadata": {
                    # Caller metadata is spread first so it cannot overwrite
                    # the reserved keys that follow.
                    **item.metadata,
                    "request_id": item.request_id,
                    "priority": item.priority.name,
                    "session_id": item.session_id,
                    "queued_at": item.queued_at.isoformat(),
                },
                "attachments": item.attachments,
            }
manager = ConversationManager(session) + ... + ... # Queue messages - returns immediately + ... await manager.queue_message("req-1", "Hello", Priority.NORMAL) + ... await manager.queue_message("req-2", "Urgent!", Priority.URGENT) + ... + ... # Stop when done + ... await manager.stop() + """ + + def __init__( + self, + session: CopilotSession, + max_queue_depth: int = 100, + on_response: Optional[Callable[[dict[str, Any]], None]] = None, + ): + """Initialize the conversation manager. + + Args: + session: The CopilotSession to send messages to. + max_queue_depth: Maximum number of queued messages. Defaults to 100. + on_response: Optional callback for responses/events from the session. + """ + self._session = session + self._max_queue_depth = max_queue_depth + self._on_response = on_response + self._queue: Optional[MessageQueue] = None + self._generator: Optional[StreamingInputGenerator] = None + self._processor_task: Optional[asyncio.Task[None]] = None + self._started = False + self._request_counter = 0 + + @property + def session(self) -> CopilotSession: + """The underlying session.""" + return self._session + + @property + def is_started(self) -> bool: + """Whether the manager has been started.""" + return self._started + + @property + def queue_size(self) -> int: + """Current number of queued messages.""" + return self._queue.qsize() if self._queue else 0 + + def _generate_request_id(self) -> str: + """Generate a unique request ID.""" + self._request_counter += 1 + return f"req-{self._request_counter}" + + async def queue_message( + self, + content: str, + priority: Priority = Priority.NORMAL, + request_id: Optional[str] = None, + attachments: Optional[list[Attachment]] = None, + metadata: Optional[dict[str, Any]] = None, + ) -> str: + """Queue a message for processing. + + This method returns immediately, even if the CLI is busy processing + a previous message. Messages are processed in priority order. + + Args: + content: The message content/prompt. 
+ priority: Processing priority. Defaults to NORMAL. + request_id: Optional custom request ID. Auto-generated if not provided. + attachments: Optional file/directory attachments. + metadata: Optional additional metadata. + + Returns: + The request ID for this message. + + Raises: + QueueFullError: If the queue is at max capacity. + """ + # Auto-start on first message + if not self._started: + await self._start() + + if request_id is None: + request_id = self._generate_request_id() + + msg = QueuedMessage( + request_id=request_id, + content=content, + priority=priority, + session_id=self._session.session_id, + attachments=attachments or [], + metadata=metadata or {}, + ) + await self._queue.put(msg) # type: ignore[union-attr] + return request_id + + async def _start(self) -> None: + """Start the conversation processing loop.""" + if self._started: + return + + self._queue = MessageQueue(max_depth=self._max_queue_depth) + self._generator = StreamingInputGenerator(self._queue) + self._started = True + + # Start the processor task that consumes from the generator + self._processor_task = asyncio.create_task(self._process_messages()) + + async def _process_messages(self) -> None: + """Process messages from the generator, sending them to the session.""" + if self._generator is None: + return + + async for message_dict in self._generator: + try: + # Extract message details + content = message_dict["message"]["content"] + attachments = message_dict.get("attachments", []) + request_id = message_dict["metadata"]["request_id"] + + # Send to session - this is non-blocking in the SDK + # The session.send returns a message ID + await self._session.send( + { + "prompt": content, + "attachments": attachments if attachments else None, + } + ) + + # If we have a response callback, we could wire it up here + # For now, the caller should use session.on() for events + + except Exception as e: + # Log error but continue processing + print(f"Error processing message 
{message_dict.get('metadata', {}).get('request_id', 'unknown')}: {e}") + + async def stop(self, timeout: Optional[float] = None) -> None: + """Gracefully stop the conversation manager. + + Signals shutdown and waits for pending messages to be processed. + + Args: + timeout: Maximum time to wait for pending messages. Defaults to None (wait forever). + """ + if not self._started or self._queue is None: + return + + # Signal shutdown + self._queue.signal_shutdown() + + # Wait for processor task to complete + if self._processor_task is not None: + try: + if timeout is not None: + await asyncio.wait_for(self._processor_task, timeout=timeout) + else: + await self._processor_task + except asyncio.TimeoutError: + self._processor_task.cancel() + try: + await self._processor_task + except asyncio.CancelledError: + pass + except asyncio.CancelledError: + pass + + self._started = False + self._queue = None + self._generator = None + self._processor_task = None + + async def clear_queue(self) -> int: + """Clear all pending messages from the queue. + + Returns: + The number of messages that were cleared. + """ + if self._queue is None: + return 0 + + count = 0 + while not self._queue.empty(): + try: + # We need to drain the queue + # Create a new empty queue and swap + item = self._queue._queue.get_nowait() + if not isinstance(item, ShutdownSentinel): + count += 1 + except asyncio.QueueEmpty: + break + + return count + + async def __aenter__(self) -> "ConversationManager": + """Async context manager entry.""" + return self + + async def __aexit__( + self, + exc_type: Optional[type], + exc_val: Optional[BaseException], + exc_tb: Optional[Any], + ) -> None: + """Async context manager exit - ensures clean shutdown.""" + await self.stop() diff --git a/python/test_steering.py b/python/test_steering.py new file mode 100644 index 000000000..2bb71f046 --- /dev/null +++ b/python/test_steering.py @@ -0,0 +1,563 @@ +""" +Unit tests for the steering module. 
import asyncio
from datetime import datetime
from unittest.mock import AsyncMock, MagicMock

import pytest

from copilot.steering import (
    SHUTDOWN_SENTINEL,
    ConversationManager,
    MessageQueue,
    Priority,
    QueuedMessage,
    QueueFullError,
    ShutdownSentinel,
    StreamingInputGenerator,
)


def _message(request_id, priority=Priority.NORMAL, **overrides):
    """Build a QueuedMessage, defaulting content to the ID and session to "s"."""
    fields = {"content": request_id, "session_id": "s"}
    fields.update(overrides)
    return QueuedMessage(request_id=request_id, priority=priority, **fields)


def _session(send=None):
    """Build a stand-in session with a ``session_id`` and an awaitable ``send``."""
    session = MagicMock()
    session.session_id = "test-session"
    session.send = send if send is not None else AsyncMock(return_value="msg-id")
    return session


async def _collect(generator):
    """Exhaust an async iterable into a list."""
    return [message async for message in generator]


class TestPriority:
    def test_priority_ordering(self):
        """Priority values are ordered LOW < NORMAL < HIGH < URGENT."""
        assert Priority.LOW.value == 0
        assert Priority.NORMAL.value == 1
        assert Priority.HIGH.value == 2
        assert Priority.URGENT.value == 3

        assert Priority.URGENT > Priority.HIGH > Priority.NORMAL > Priority.LOW


class TestQueuedMessage:
    def test_creation(self):
        """Fields are stored as given and queued_at is populated."""
        msg = _message("test-1", content="Hello", session_id="session-1")

        assert msg.request_id == "test-1"
        assert msg.content == "Hello"
        assert msg.priority == Priority.NORMAL
        assert msg.session_id == "session-1"
        assert isinstance(msg.queued_at, datetime)

    def test_priority_comparison(self):
        """Higher priority messages sort first (compare as "less than")."""
        low = _message("low", Priority.LOW, sequence_number=1)
        normal = _message("normal", Priority.NORMAL, sequence_number=2)
        urgent = _message("urgent", Priority.URGENT, sequence_number=3)

        assert urgent < normal < low
        assert not low < urgent

    def test_fifo_within_same_priority(self):
        """Within one priority the earlier sequence number sorts first."""
        first = _message("first", sequence_number=1)
        second = _message("second", sequence_number=2)

        assert first < second
        assert not second < first

    def test_comparison_with_sentinel(self):
        """Messages sort before the shutdown sentinel."""
        msg = _message("msg", Priority.URGENT)

        assert msg < SHUTDOWN_SENTINEL
        assert not SHUTDOWN_SENTINEL < msg


class TestShutdownSentinel:
    def test_singleton(self):
        """SHUTDOWN_SENTINEL is a ShutdownSentinel instance."""
        assert isinstance(SHUTDOWN_SENTINEL, ShutdownSentinel)

    def test_comparison_with_messages(self):
        """The sentinel sorts after even the lowest-priority message."""
        msg = _message("msg", Priority.LOW)

        assert msg < SHUTDOWN_SENTINEL
        assert SHUTDOWN_SENTINEL > msg
        assert not SHUTDOWN_SENTINEL < msg

    def test_comparison_with_itself(self):
        """Two sentinels compare as equal in ordering terms."""
        sentinel1 = ShutdownSentinel()
        sentinel2 = ShutdownSentinel()

        assert not sentinel1 < sentinel2
        assert sentinel1 <= sentinel2
        assert sentinel1 >= sentinel2


class TestMessageQueue:
    @pytest.mark.asyncio
    async def test_put_and_get(self):
        """A message put on the queue comes back from get()."""
        queue = MessageQueue(max_depth=10)
        msg = _message("test", content="Hello", session_id="session-1")

        await queue.put(msg)
        assert queue.qsize() == 1

        result = await queue.get()
        assert result == msg
        assert queue.qsize() == 0

    @pytest.mark.asyncio
    async def test_priority_ordering(self):
        """Messages are returned in priority order, not insertion order."""
        queue = MessageQueue(max_depth=10)

        for request_id, priority in (
            ("low", Priority.LOW),
            ("urgent", Priority.URGENT),
            ("normal", Priority.NORMAL),
        ):
            await queue.put(_message(request_id, priority))

        assert (await queue.get()).request_id == "urgent"
        assert (await queue.get()).request_id == "normal"
        assert (await queue.get()).request_id == "low"

    @pytest.mark.asyncio
    async def test_queue_full_error(self):
        """QueueFullError is raised once max_depth messages are queued."""
        queue = MessageQueue(max_depth=2)

        await queue.put(_message("1"))
        await queue.put(_message("2"))

        with pytest.raises(QueueFullError):
            await queue.put(_message("3"))

    @pytest.mark.asyncio
    async def test_shutdown_signal(self):
        """signal_shutdown() flips is_shutdown() and enqueues the sentinel."""
        queue = MessageQueue()

        assert not queue.is_shutdown()
        queue.signal_shutdown()
        assert queue.is_shutdown()

        result = await queue.get()
        assert isinstance(result, ShutdownSentinel)

    @pytest.mark.asyncio
    async def test_empty_and_full(self):
        """empty() and full() track the queue depth."""
        queue = MessageQueue(max_depth=2)

        assert queue.empty()
        assert not queue.full()

        await queue.put(_message("1"))

        assert not queue.empty()
        assert not queue.full()

        await queue.put(_message("2"))

        assert queue.full()


class TestStreamingInputGenerator:
    @pytest.mark.asyncio
    async def test_yields_messages(self):
        """The generator yields messages in the documented dict format."""
        queue = MessageQueue()
        generator = StreamingInputGenerator(queue)

        await queue.put(
            _message(
                "test-1",
                content="Hello world",
                session_id="session-123",
                metadata={"custom": "value"},
            )
        )
        queue.signal_shutdown()

        messages = await _collect(generator)

        assert len(messages) == 1
        assert messages[0]["type"] == "user"
        assert messages[0]["message"]["role"] == "user"
        assert messages[0]["message"]["content"] == "Hello world"
        assert messages[0]["metadata"]["request_id"] == "test-1"
        assert messages[0]["metadata"]["priority"] == "NORMAL"
        assert messages[0]["metadata"]["session_id"] == "session-123"
        assert messages[0]["metadata"]["custom"] == "value"

    @pytest.mark.asyncio
    async def test_stops_on_shutdown(self):
        """Iteration ends when the shutdown sentinel is received."""
        queue = MessageQueue()
        generator = StreamingInputGenerator(queue)

        await queue.put(_message("1", content="First"))
        await queue.put(_message("2", content="Second"))
        queue.signal_shutdown()

        messages = await _collect(generator)

        assert len(messages) == 2

    @pytest.mark.asyncio
    async def test_priority_order_preserved(self):
        """Messages come out of the generator in priority order."""
        queue = MessageQueue()
        generator = StreamingInputGenerator(queue)

        await queue.put(_message("low", Priority.LOW))
        await queue.put(_message("urgent", Priority.URGENT))
        queue.signal_shutdown()

        messages = await _collect(generator)

        assert messages[0]["metadata"]["request_id"] == "urgent"
        assert messages[1]["metadata"]["request_id"] == "low"


class TestConversationManager:
    @pytest.mark.asyncio
    async def test_queue_message(self):
        """Queuing a message returns an auto-generated ID and starts the manager."""
        manager = ConversationManager(_session())

        request_id = await manager.queue_message("Hello", Priority.NORMAL)
        assert request_id == "req-1"
        assert manager.is_started
        assert manager.queue_size >= 0  # May have been processed already

        await manager.stop()

    @pytest.mark.asyncio
    async def test_auto_start(self):
        """The manager auto-starts on the first message."""
        manager = ConversationManager(_session())
        assert not manager.is_started

        await manager.queue_message("Hello", Priority.NORMAL)
        assert manager.is_started

        await manager.stop()

    @pytest.mark.asyncio
    async def test_custom_request_id(self):
        """A caller-supplied request ID is returned unchanged."""
        manager = ConversationManager(_session())

        request_id = await manager.queue_message(
            "Hello", Priority.NORMAL, request_id="custom-123"
        )
        assert request_id == "custom-123"

        await manager.stop()

    @pytest.mark.asyncio
    async def test_stop_without_start(self):
        """stop() is a no-op if the manager was never started."""
        manager = ConversationManager(_session())
        await manager.stop()  # Should not raise

    @pytest.mark.asyncio
    async def test_context_manager(self):
        """Leaving the async context stops the manager."""
        async with ConversationManager(_session()) as manager:
            await manager.queue_message("Hello", Priority.NORMAL)
            assert manager.is_started

        assert not manager.is_started

    @pytest.mark.asyncio
    async def test_message_sends_to_session(self):
        """Queued messages are sent to the session."""
        session = _session()
        manager = ConversationManager(session)

        await manager.queue_message("Hello world", Priority.NORMAL)

        # stop() drains pending messages before returning, so no sleep is needed.
        await manager.stop()

        session.send.assert_called()
        call_args = session.send.call_args[0][0]
        assert call_args["prompt"] == "Hello world"

    @pytest.mark.asyncio
    async def test_priority_processing_order(self):
        """Messages are processed in priority order."""
        processed_prompts = []

        async def capture_send(options):
            processed_prompts.append(options["prompt"])
            return "msg-id"

        manager = ConversationManager(_session(send=capture_send))

        # Queue messages in arbitrary order
        await manager.queue_message("Low priority", Priority.LOW)
        await manager.queue_message("Urgent!", Priority.URGENT)
        await manager.queue_message("Normal", Priority.NORMAL)

        # stop() drains pending messages before returning.
        await manager.stop()

        assert processed_prompts == ["Urgent!", "Normal", "Low priority"]

    @pytest.mark.asyncio
    async def test_queue_full_error(self):
        """QueueFullError is raised when the manager's queue is full."""
        # A blocking send keeps the processor busy so the queue cannot drain.
        send_started = asyncio.Event()
        send_continue = asyncio.Event()

        async def blocking_send(options):
            send_started.set()
            await send_continue.wait()
            return "msg-id"

        manager = ConversationManager(
            _session(send=blocking_send), max_queue_depth=2
        )

        # The first message is picked up by the processor immediately.
        await manager.queue_message("1", Priority.NORMAL)
        await send_started.wait()

        # These two fill the queue while message 1 is still being sent.
        await manager.queue_message("2", Priority.NORMAL)
        await manager.queue_message("3", Priority.NORMAL)

        with pytest.raises(QueueFullError):
            await manager.queue_message("4", Priority.NORMAL)

        # Unblock and cleanup
        send_continue.set()
        await manager.stop(timeout=1.0)

    @pytest.mark.asyncio
    async def test_attachments_forwarded(self):
        """Attachments are forwarded to the session."""
        session = _session()
        manager = ConversationManager(session)

        attachments = [{"type": "file", "path": "/test/file.py"}]
        await manager.queue_message(
            "Check this file", Priority.NORMAL, attachments=attachments
        )

        # stop() drains pending messages before returning.
        await manager.stop()

        session.send.assert_called()
        call_args = session.send.call_args[0][0]
        assert call_args["attachments"] == attachments


class TestYieldWhileProcessing:
    """Verify that messages can be queued while an earlier one is processing."""

    @pytest.mark.asyncio
    async def test_yield_while_processing(self):
        """New messages can be queued while a previous one is still being sent."""
        first_processing_started = asyncio.Event()
        first_processing_done = asyncio.Event()
        messages_yielded = []

        async def slow_send(options):
            messages_yielded.append(options["prompt"])
            if len(messages_yielded) == 1:
                first_processing_started.set()
                # Hold the first send open until the test has queued more.
                await first_processing_done.wait()
            return "msg-id"

        manager = ConversationManager(_session(send=slow_send))

        await manager.queue_message("message 1", Priority.NORMAL)
        await first_processing_started.wait()

        # These must be accepted while message 1 is still processing.
        await manager.queue_message("message 2", Priority.NORMAL)
        await manager.queue_message("message 3", Priority.NORMAL)

        assert manager.queue_size >= 2

        first_processing_done.set()

        # stop() drains pending messages before returning.
        await manager.stop()

        assert messages_yielded == ["message 1", "message 2", "message 3"]