From 2e83485810e08c23128b963b09dec54fad05537a Mon Sep 17 00:00:00 2001 From: Andrej Simurka Date: Mon, 2 Feb 2026 17:02:04 +0100 Subject: [PATCH] Added tool calls and timestamps into turn history --- docs/openapi.json | 158 ++- src/app/endpoints/conversations.py | 390 ------ ...onversations_v3.py => conversations_v1.py} | 130 +- src/app/endpoints/conversations_v2.py | 56 +- src/app/endpoints/query.py | 46 +- src/app/routers.py | 5 +- src/models/database/conversations.py | 32 +- src/models/responses.py | 98 +- src/utils/conversations.py | 382 ++++++ src/utils/endpoints.py | 4 +- .../features/conversation_cache_v2.feature | 10 +- tests/e2e/features/conversations.feature | 13 +- tests/integration/test_openapi_json.py | 4 +- .../unit/app/endpoints/test_conversations.py | 1158 ++++++++++++----- .../app/endpoints/test_conversations_v2.py | 246 +--- tests/unit/app/test_routers.py | 6 +- .../responses/test_successful_responses.py | 8 +- tests/unit/utils/test_conversations.py | 722 ++++++++++ 18 files changed, 2386 insertions(+), 1082 deletions(-) delete mode 100644 src/app/endpoints/conversations.py rename src/app/endpoints/{conversations_v3.py => conversations_v1.py} (84%) create mode 100644 src/utils/conversations.py create mode 100644 tests/unit/utils/test_conversations.py diff --git a/docs/openapi.json b/docs/openapi.json index 8954eef46..29b43630e 100644 --- a/docs/openapi.json +++ b/docs/openapi.json @@ -2445,26 +2445,6 @@ } } } - }, - "503": { - "description": "Service unavailable", - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/ServiceUnavailableResponse" - }, - "examples": { - "llama stack": { - "value": { - "detail": { - "cause": "Connection error while trying to reach backend service.", - "response": "Unable to connect to Llama Stack" - } - } - } - } - } - } } } } @@ -2510,7 +2490,11 @@ "type": "assistant" } ], - "started_at": "2024-01-01T00:01:00Z" + "model": "gpt-4o-mini", + "provider": "openai", + "started_at": "2024-01-01T00:01:00Z", + "tool_calls": [], + "tool_results": [] } ], "conversation_id": "123e4567-e89b-12d3-a456-426614174000" @@ -3196,7 +3180,11 @@ "type": "assistant" } ], - "started_at": "2024-01-01T00:01:00Z" + "model": "gpt-4o-mini", + "provider": "openai", + "started_at": "2024-01-01T00:01:00Z", + "tool_calls": [], + "tool_results": [] } ], "conversation_id": "123e4567-e89b-12d3-a456-426614174000" @@ -4290,7 +4278,7 @@ ], "summary": "Handle A2A Jsonrpc", "description": "Handle A2A JSON-RPC requests following the A2A protocol specification.\n\nThis endpoint uses the DefaultRequestHandler from the A2A SDK to handle\nall JSON-RPC requests including message/send, message/stream, etc.\n\nThe A2A SDK application is created per-request to include authentication\ncontext while still leveraging FastAPI's authorization middleware.\n\nAutomatically detects streaming requests (message/stream JSON-RPC method)\nand returns a StreamingResponse to enable real-time chunk delivery.\n\nArgs:\n request: FastAPI request object\n auth: Authentication tuple\n mcp_headers: MCP headers for context propagation\n\nReturns:\n JSON-RPC response or streaming response", - "operationId": "handle_a2a_jsonrpc_a2a_post", + "operationId": "handle_a2a_jsonrpc_a2a_get", "responses": { "200": { "description": "Successful Response", @@ -4308,7 +4296,7 @@ ], "summary": "Handle A2A Jsonrpc", "description": "Handle A2A JSON-RPC requests following the A2A protocol specification.\n\nThis endpoint uses the DefaultRequestHandler from the A2A SDK to handle\nall JSON-RPC requests 
including message/send, message/stream, etc.\n\nThe A2A SDK application is created per-request to include authentication\ncontext while still leveraging FastAPI's authorization middleware.\n\nAutomatically detects streaming requests (message/stream JSON-RPC method)\nand returns a StreamingResponse to enable real-time chunk delivery.\n\nArgs:\n request: FastAPI request object\n auth: Authentication tuple\n mcp_headers: MCP headers for context propagation\n\nReturns:\n JSON-RPC response or streaming response", - "operationId": "handle_a2a_jsonrpc_a2a_post", + "operationId": "handle_a2a_jsonrpc_a2a_get", "responses": { "200": { "description": "Successful Response", @@ -5924,8 +5912,7 @@ }, "chat_history": { "items": { - "additionalProperties": true, - "type": "object" + "$ref": "#/components/schemas/ConversationTurn" }, "type": "array", "title": "Chat History", @@ -5943,7 +5930,11 @@ "type": "assistant" } ], - "started_at": "2024-01-01T00:01:00Z" + "model": "gpt-4o-mini", + "provider": "openai", + "started_at": "2024-01-01T00:01:00Z", + "tool_calls": [], + "tool_results": [] } ] } @@ -5954,7 +5945,7 @@ "chat_history" ], "title": "ConversationResponse", - "description": "Model representing a response for retrieving a conversation.\n\nAttributes:\n conversation_id: The conversation ID (UUID).\n chat_history: The simplified chat history as a list of conversation turns.\n\nExample:\n ```python\n conversation_response = ConversationResponse(\n conversation_id=\"123e4567-e89b-12d3-a456-426614174000\",\n chat_history=[\n {\n \"messages\": [\n {\"content\": \"Hello\", \"type\": \"user\"},\n {\"content\": \"Hi there!\", \"type\": \"assistant\"}\n ],\n \"started_at\": \"2024-01-01T00:01:00Z\",\n \"completed_at\": \"2024-01-01T00:01:05Z\"\n }\n ]\n )\n ```", + "description": "Model representing a response for retrieving a conversation.\n\nAttributes:\n conversation_id: The conversation ID (UUID).\n chat_history: The chat history as a list of conversation turns.", "examples": [ { "chat_history": [ @@ -5970,13 +5961,86 @@ "type": "assistant" } ], - "started_at": "2024-01-01T00:01:00Z" + "model": "gpt-4o-mini", + "provider": "openai", + "started_at": "2024-01-01T00:01:00Z", + "tool_calls": [], + "tool_results": [] } ], "conversation_id": "123e4567-e89b-12d3-a456-426614174000" } ] }, + "ConversationTurn": { + "properties": { + "messages": { + "items": { + "$ref": "#/components/schemas/Message" + }, + "type": "array", + "title": "Messages", + "description": "List of messages in this turn" + }, + "tool_calls": { + "items": { + "$ref": "#/components/schemas/ToolCallSummary" + }, + "type": "array", + "title": "Tool Calls", + "description": "List of tool calls made in this turn" + }, + "tool_results": { + "items": { + "$ref": "#/components/schemas/ToolResultSummary" + }, + "type": "array", + "title": "Tool Results", + "description": "List of tool results from this turn" + }, + "provider": { + "type": "string", + "title": "Provider", + "description": "Provider identifier used for this turn", + "examples": [ + "openai" + ] + }, + "model": { + "type": "string", + "title": "Model", + "description": "Model identifier used for this turn", + "examples": [ + "gpt-4o-mini" + ] + }, + "started_at": { + "type": "string", + "title": "Started At", + "description": "ISO 8601 timestamp when the turn started", + "examples": [ + "2024-01-01T00:01:00Z" + ] + }, + "completed_at": { + "type": "string", + "title": "Completed At", + "description": "ISO 8601 timestamp when the turn completed", + "examples": [ + "2024-01-01T00:01:05Z" 
+ ] + } + }, + "type": "object", + "required": [ + "provider", + "model", + "started_at", + "completed_at" + ], + "title": "ConversationTurn", + "description": "Model representing a single conversation turn.\n\nAttributes:\n messages: List of messages in this turn.\n tool_calls: List of tool calls made in this turn.\n tool_results: List of tool results from this turn.\n provider: Provider identifier used for this turn.\n model: Model identifier used for this turn.\n started_at: ISO 8601 timestamp when the turn started.\n completed_at: ISO 8601 timestamp when the turn completed." + }, "ConversationUpdateRequest": { "properties": { "topic_summary": { @@ -7029,6 +7093,42 @@ "title": "MCPServerAuthInfo", "description": "Information about MCP server client authentication options." }, + "Message": { + "properties": { + "content": { + "type": "string", + "title": "Content", + "description": "The message content", + "examples": [ + "Hello, how can I help you?" + ] + }, + "type": { + "type": "string", + "enum": [ + "user", + "assistant", + "system", + "developer" + ], + "title": "Type", + "description": "The type of message", + "examples": [ + "user", + "assistant", + "system", + "developer" + ] + } + }, + "type": "object", + "required": [ + "content", + "type" + ], + "title": "Message", + "description": "Model representing a message in a conversation turn.\n\nAttributes:\n content: The message content.\n type: The type of message." + }, "ModelContextProtocolServer": { "properties": { "name": { diff --git a/src/app/endpoints/conversations.py b/src/app/endpoints/conversations.py deleted file mode 100644 index 1e9891370..000000000 --- a/src/app/endpoints/conversations.py +++ /dev/null @@ -1,390 +0,0 @@ -"""Handler for REST API calls to manage conversation history.""" - -import logging -from typing import Any - -from fastapi import APIRouter, Depends, HTTPException, Request -from llama_stack_client import APIConnectionError, NotFoundError -from sqlalchemy.exc import SQLAlchemyError - -from app.database import get_session -from authentication import get_auth_dependency -from authorization.middleware import authorize -from client import AsyncLlamaStackClientHolder -from configuration import configuration -from models.config import Action -from models.database.conversations import UserConversation -from models.responses import ( - BadRequestResponse, - ConversationDeleteResponse, - ConversationDetails, - ConversationResponse, - ConversationsListResponse, - ForbiddenResponse, - InternalServerErrorResponse, - NotFoundResponse, - ServiceUnavailableResponse, - UnauthorizedResponse, -) -from utils.endpoints import ( - can_access_conversation, - check_configuration_loaded, - delete_conversation, - retrieve_conversation, -) -from utils.suid import check_suid - -logger = logging.getLogger("app.endpoints.handlers") -router = APIRouter(tags=["conversations"]) - - -conversation_get_responses: dict[int | str, dict[str, Any]] = { - 200: ConversationResponse.openapi_response(), - 400: BadRequestResponse.openapi_response(), - 401: UnauthorizedResponse.openapi_response( - examples=["missing header", "missing token"] - ), - 403: ForbiddenResponse.openapi_response(examples=["conversation read", "endpoint"]), - 404: NotFoundResponse.openapi_response(examples=["conversation"]), - 500: InternalServerErrorResponse.openapi_response( - examples=["database", "configuration"] - ), - 503: ServiceUnavailableResponse.openapi_response(), -} - -conversation_delete_responses: dict[int | str, dict[str, Any]] = { - 200: 
ConversationDeleteResponse.openapi_response(), - 400: BadRequestResponse.openapi_response(), - 401: UnauthorizedResponse.openapi_response( - examples=["missing header", "missing token"] - ), - 403: ForbiddenResponse.openapi_response( - examples=["conversation delete", "endpoint"] - ), - 404: NotFoundResponse.openapi_response(examples=["conversation"]), - 500: InternalServerErrorResponse.openapi_response( - examples=["database", "configuration"] - ), - 503: ServiceUnavailableResponse.openapi_response(), -} - -conversations_list_responses: dict[int | str, dict[str, Any]] = { - 200: ConversationsListResponse.openapi_response(), - 401: UnauthorizedResponse.openapi_response( - examples=["missing header", "missing token"] - ), - 403: ForbiddenResponse.openapi_response(examples=["endpoint"]), - 500: InternalServerErrorResponse.openapi_response( - examples=["database", "configuration"] - ), - 503: ServiceUnavailableResponse.openapi_response(), -} - - -def simplify_session_data(session_data: dict) -> list[dict[str, Any]]: - """Simplify session data to include only essential conversation information. - - Args: - session_data: The full session data dict from llama-stack - - Returns: - Simplified session data with only input_messages and output_message per turn - """ - # Create simplified structure - chat_history = [] - - # Extract only essential data from each turn - for turn in session_data.get("turns", []): - # Clean up input messages - cleaned_messages = [] - for msg in turn.get("input_messages", []): - cleaned_msg = { - "content": msg.get("content"), - "type": msg.get("role"), # Rename role to type - } - cleaned_messages.append(cleaned_msg) - - # Clean up output message - output_msg = turn.get("output_message", {}) - cleaned_messages.append( - { - "content": output_msg.get("content"), - "type": output_msg.get("role"), # Rename role to type - } - ) - - simplified_turn = { - "messages": cleaned_messages, - "started_at": turn.get("started_at"), - "completed_at": turn.get("completed_at"), - } - chat_history.append(simplified_turn) - - return chat_history - - -@router.get("/conversations", responses=conversations_list_responses) -@authorize(Action.LIST_CONVERSATIONS) -async def get_conversations_list_endpoint_handler( - request: Request, - auth: Any = Depends(get_auth_dependency()), -) -> ConversationsListResponse: - """Handle request to retrieve all conversations for the authenticated user.""" - check_configuration_loaded(configuration) - - user_id = auth[0] - - logger.info("Retrieving conversations for user %s", user_id) - - with get_session() as session: - try: - query = session.query(UserConversation) - - filtered_query = ( - query - if Action.LIST_OTHERS_CONVERSATIONS in request.state.authorized_actions - else query.filter_by(user_id=user_id) - ) - - user_conversations = filtered_query.all() - - # Return conversation summaries with metadata - conversations = [ - ConversationDetails( - conversation_id=conv.id, - created_at=conv.created_at.isoformat() if conv.created_at else None, - last_message_at=( - conv.last_message_at.isoformat() - if conv.last_message_at - else None - ), - message_count=conv.message_count, - last_used_model=conv.last_used_model, - last_used_provider=conv.last_used_provider, - topic_summary=conv.topic_summary, - ) - for conv in user_conversations - ] - - logger.info( - "Found %d conversations for user %s", len(conversations), user_id - ) - - return ConversationsListResponse(conversations=conversations) - - except SQLAlchemyError as e: - logger.exception( - "Error retrieving 
conversations for user %s: %s", user_id, e - ) - response = InternalServerErrorResponse.database_error() - raise HTTPException(**response.model_dump()) from e - - -@router.get("/conversations/{conversation_id}", responses=conversation_get_responses) -@authorize(Action.GET_CONVERSATION) -async def get_conversation_endpoint_handler( - request: Request, - conversation_id: str, - auth: Any = Depends(get_auth_dependency()), -) -> ConversationResponse: - """ - Handle request to retrieve a conversation by ID. - - Retrieve a conversation's chat history by its ID. Then fetches - the conversation session from the Llama Stack backend, - simplifies the session data to essential chat history, and - returns it in a structured response. Raises HTTP 400 for - invalid IDs, 404 if not found, 503 if the backend is - unavailable, and 500 for unexpected errors. - - Parameters: - conversation_id (str): Unique identifier of the conversation to retrieve. - - Returns: - ConversationResponse: Structured response containing the conversation - ID and simplified chat history. - """ - check_configuration_loaded(configuration) - - # Validate conversation ID format - if not check_suid(conversation_id): - logger.error("Invalid conversation ID format: %s", conversation_id) - response = BadRequestResponse( - resource="conversation", resource_id=conversation_id - ) - raise HTTPException(**response.model_dump()) - - user_id = auth[0] - if not can_access_conversation( - conversation_id, - user_id, - others_allowed=( - Action.READ_OTHERS_CONVERSATIONS in request.state.authorized_actions - ), - ): - logger.warning( - "User %s attempted to read conversation %s they don't have access to", - user_id, - conversation_id, - ) - response = ForbiddenResponse.conversation( - action="read", resource_id=conversation_id, user_id=user_id - ) - raise HTTPException(**response.model_dump()) - - # If reached this, user is authorized to retreive this conversation - conversation = retrieve_conversation(conversation_id) - if conversation is None: - response = NotFoundResponse( - resource="conversation", resource_id=conversation_id - ) - raise HTTPException(**response.model_dump()) - - agent_id = conversation_id - logger.info("Retrieving conversation %s", conversation_id) - - try: - client = AsyncLlamaStackClientHolder().get_client() - - agent_sessions = (await client.agents.session.list(agent_id=agent_id)).data - if not agent_sessions: - logger.error("No sessions found for conversation %s", conversation_id) - response = NotFoundResponse( - resource="conversation", resource_id=conversation_id - ) - raise HTTPException(**response.model_dump()) - session_id = str(agent_sessions[0].get("session_id")) - - session_response = await client.agents.session.retrieve( - agent_id=agent_id, session_id=session_id - ) - session_data = session_response.model_dump() - - logger.info("Successfully retrieved conversation %s", conversation_id) - - # Simplify the session data to include only essential conversation information - chat_history = simplify_session_data(session_data) - - return ConversationResponse( - conversation_id=conversation_id, - chat_history=chat_history, - ) - - except APIConnectionError as e: - logger.error("Unable to connect to Llama Stack: %s", e) - response = ServiceUnavailableResponse(backend_name="Llama Stack", cause=str(e)) - raise HTTPException(**response.model_dump()) from e - - except NotFoundError as e: - logger.error("Conversation not found: %s", e) - response = NotFoundResponse( - resource="conversation", resource_id=conversation_id - ) - 
raise HTTPException(**response.model_dump()) from e - - except SQLAlchemyError as e: - logger.exception("Error retrieving conversation %s: %s", conversation_id, e) - response = InternalServerErrorResponse.database_error() - raise HTTPException(**response.model_dump()) from e - - -@router.delete( - "/conversations/{conversation_id}", responses=conversation_delete_responses -) -@authorize(Action.DELETE_CONVERSATION) -async def delete_conversation_endpoint_handler( - request: Request, - conversation_id: str, - auth: Any = Depends(get_auth_dependency()), -) -> ConversationDeleteResponse: - """ - Handle request to delete a conversation by ID. - - Validates the conversation ID format and attempts to delete the - corresponding session from the Llama Stack backend. Raises HTTP - errors for invalid IDs, not found conversations, connection - issues, or unexpected failures. - - Returns: - ConversationDeleteResponse: Response indicating the result of the deletion operation. - """ - check_configuration_loaded(configuration) - - # Validate conversation ID format - if not check_suid(conversation_id): - logger.error("Invalid conversation ID format: %s", conversation_id) - response = BadRequestResponse( - resource="conversation", resource_id=conversation_id - ) - raise HTTPException(**response.model_dump()) - - user_id = auth[0] - if not can_access_conversation( - conversation_id, - user_id, - others_allowed=( - Action.DELETE_OTHERS_CONVERSATIONS in request.state.authorized_actions - ), - ): - logger.warning( - "User %s attempted to delete conversation %s they don't have access to", - user_id, - conversation_id, - ) - response = ForbiddenResponse.conversation( - action="delete", resource_id=conversation_id, user_id=user_id - ) - raise HTTPException(**response.model_dump()) - - # If reached this, user is authorized to retreive this conversation - conversation = retrieve_conversation(conversation_id) - if conversation is None: - response = NotFoundResponse( - resource="conversation", resource_id=conversation_id - ) - raise HTTPException(**response.model_dump()) - - agent_id = conversation_id - logger.info("Deleting conversation %s", conversation_id) - - try: - # Get Llama Stack client - client = AsyncLlamaStackClientHolder().get_client() - - agent_sessions = (await client.agents.session.list(agent_id=agent_id)).data - - if not agent_sessions: - # If no sessions are found, do not raise an error, just return a success response - logger.info("No sessions found for conversation %s", conversation_id) - return ConversationDeleteResponse( - deleted=False, - conversation_id=conversation_id, - ) - - session_id = str(agent_sessions[0].get("session_id")) - - await client.agents.session.delete(agent_id=agent_id, session_id=session_id) - - logger.info("Successfully deleted conversation %s", conversation_id) - - delete_conversation(conversation_id=conversation_id) - - return ConversationDeleteResponse( - deleted=True, - conversation_id=conversation_id, - ) - - except APIConnectionError as e: - response = ServiceUnavailableResponse(backend_name="Llama Stack", cause=str(e)) - raise HTTPException(**response.model_dump()) from e - - except NotFoundError as e: - response = NotFoundResponse( - resource="conversation", resource_id=conversation_id - ) - raise HTTPException(**response.model_dump()) from e - - except SQLAlchemyError as e: - logger.exception("Error deleting conversation %s: %s", conversation_id, e) - response = InternalServerErrorResponse.database_error() - raise HTTPException(**response.model_dump()) from e diff 
--git a/src/app/endpoints/conversations_v3.py b/src/app/endpoints/conversations_v1.py similarity index 84% rename from src/app/endpoints/conversations_v3.py rename to src/app/endpoints/conversations_v1.py index ff9f8058b..733b24f2a 100644 --- a/src/app/endpoints/conversations_v3.py +++ b/src/app/endpoints/conversations_v1.py @@ -3,7 +3,7 @@ import logging from typing import Any -from fastapi import APIRouter, Depends, HTTPException, Request, status +from fastapi import APIRouter, Depends, HTTPException, Request from llama_stack_client import ( APIConnectionError, APIStatusError, @@ -16,7 +16,10 @@ from client import AsyncLlamaStackClientHolder from configuration import configuration from models.config import Action -from models.database.conversations import UserConversation +from models.database.conversations import ( + UserTurn, + UserConversation, +) from models.requests import ConversationUpdateRequest from models.responses import ( BadRequestResponse, @@ -42,6 +45,7 @@ normalize_conversation_id, to_llama_stack_conversation_id, ) +from utils.conversations import build_conversation_turns_from_items logger = logging.getLogger("app.endpoints.handlers") router = APIRouter(tags=["conversations_v1"]) @@ -84,7 +88,6 @@ 500: InternalServerErrorResponse.openapi_response( examples=["database", "configuration"] ), - 503: ServiceUnavailableResponse.openapi_response(), } conversation_update_responses: dict[int | str, dict[str, Any]] = { @@ -102,68 +105,6 @@ } -def simplify_conversation_items(items: list[dict]) -> list[dict[str, Any]]: - """Simplify conversation items to include only essential information. - - Args: - items: The full conversation items list from llama-stack Conversations API - (in reverse chronological order, newest first) - - Returns: - Simplified items with only essential message and tool call information - (in chronological order, oldest first, grouped by turns) - """ - # Filter only message type items - message_items = [item for item in items if item.get("type") == "message"] - - # Process from bottom up (reverse to get chronological order) - # Assume items are grouped correctly: user input followed by assistant output - reversed_messages = list(reversed(message_items)) - - chat_history = [] - i = 0 - while i < len(reversed_messages): - # Extract text content from user message - user_item = reversed_messages[i] - user_content = user_item.get("content", []) - user_text = "" - for content_part in user_content: - if isinstance(content_part, dict): - content_type = content_part.get("type") - if content_type == "input_text": - user_text += content_part.get("text", "") - elif isinstance(content_part, str): - user_text += content_part - - # Extract text content from assistant message (next item) - assistant_text = "" - if i + 1 < len(reversed_messages): - assistant_item = reversed_messages[i + 1] - assistant_content = assistant_item.get("content", []) - for content_part in assistant_content: - if isinstance(content_part, dict): - content_type = content_part.get("type") - if content_type == "output_text": - assistant_text += content_part.get("text", "") - elif isinstance(content_part, str): - assistant_text += content_part - - # Create turn with user message first, then assistant message - chat_history.append( - { - "messages": [ - {"content": user_text, "type": "user"}, - {"content": assistant_text, "type": "assistant"}, - ] - } - ) - - # Move to next pair (skip both user and assistant) - i += 2 - - return chat_history - - @router.get( "/conversations", 
responses=conversations_list_responses, @@ -231,7 +172,7 @@ async def get_conversations_list_endpoint_handler( summary="Conversation Get Endpoint Handler V1", ) @authorize(Action.GET_CONVERSATION) -async def get_conversation_endpoint_handler( +async def get_conversation_endpoint_handler( # pylint: disable=too-many-locals,too-many-statements request: Request, conversation_id: str, auth: Any = Depends(get_auth_dependency()), @@ -334,28 +275,47 @@ async def get_conversation_endpoint_handler( after=None, include=None, limit=None, - order=None, + order="asc", # oldest first ) - items = ( - conversation_items_response.data - if hasattr(conversation_items_response, "data") - else [] - ) - # Convert items to dict format for processing - items_dicts = [ - item.model_dump() if hasattr(item, "model_dump") else dict(item) - for item in items - ] + + if not conversation_items_response.data: + logger.error("No items found for conversation %s", conversation_id) + response = NotFoundResponse( + resource="conversation", resource_id=normalized_conv_id + ).model_dump() + raise HTTPException(**response) + + items = conversation_items_response.data logger.info( "Successfully retrieved %d items for conversation %s", - len(items_dicts), + len(items), conversation_id, ) - # Simplify the conversation items to include only essential information - chat_history = simplify_conversation_items(items_dicts) + # Retrieve turns metadata from database + db_turns: list[UserTurn] = [] + try: + with get_session() as session: + db_turns = ( + session.query(UserTurn) + .filter_by(conversation_id=normalized_conv_id) + .order_by(UserTurn.turn_number) + .all() + ) + except SQLAlchemyError as e: + logger.error( + "Database error occurred while retrieving conversation turns for %s.", + normalized_conv_id, + ) + response = InternalServerErrorResponse.database_error() + raise HTTPException(**response.model_dump()) from e + + # Build conversation turns from items and populate turns metadata + # Use conversation.created_at for legacy conversations without turn metadata + chat_history = build_conversation_turns_from_items( + items, db_turns, conversation.created_at + ) - # Conversations api has no support for message level timestamps return ConversationResponse( conversation_id=normalized_conv_id, chat_history=chat_history, @@ -472,12 +432,8 @@ async def delete_conversation_endpoint_handler( ) except APIConnectionError as e: - raise HTTPException( - status_code=status.HTTP_503_SERVICE_UNAVAILABLE, - detail=ServiceUnavailableResponse( - backend_name="Llama Stack", cause=str(e) - ).model_dump(), - ) from e + response = ServiceUnavailableResponse(backend_name="Llama Stack", cause=str(e)) + raise HTTPException(**response.model_dump()) from e except APIStatusError: logger.warning( diff --git a/src/app/endpoints/conversations_v2.py b/src/app/endpoints/conversations_v2.py index a9125f74e..f9e6ebbc8 100644 --- a/src/app/endpoints/conversations_v2.py +++ b/src/app/endpoints/conversations_v2.py @@ -15,10 +15,12 @@ BadRequestResponse, ConversationDeleteResponse, ConversationResponse, + ConversationTurn, ConversationsListResponseV2, ConversationUpdateResponse, ForbiddenResponse, InternalServerErrorResponse, + Message, NotFoundResponse, UnauthorizedResponse, ) @@ -131,7 +133,10 @@ async def get_conversation_endpoint_handler( conversation = configuration.conversation_cache.get( user_id, conversation_id, skip_userid_check ) - chat_history = [transform_chat_message(entry) for entry in conversation] + # Each entry in conversation is a single turn + 
chat_history: list[ConversationTurn] = [ + build_conversation_turn_from_cache_entry(entry) for entry in conversation + ] return ConversationResponse( conversation_id=conversation_id, chat_history=chat_history @@ -238,21 +243,34 @@ def check_conversation_existence(user_id: str, conversation_id: str) -> None: raise HTTPException(**response.model_dump()) -def transform_chat_message(entry: CacheEntry) -> dict[str, Any]: - """Transform the message read from cache into format used by response payload.""" - user_message = {"content": entry.query, "type": "user"} - assistant_message: dict[str, Any] = {"content": entry.response, "type": "assistant"} - - # If referenced_documents exist on the entry, add them to the assistant message - if entry.referenced_documents is not None: - assistant_message["referenced_documents"] = [ - doc.model_dump(mode="json") for doc in entry.referenced_documents - ] - - return { - "provider": entry.provider, - "model": entry.model, - "messages": [user_message, assistant_message], - "started_at": entry.started_at, - "completed_at": entry.completed_at, - } +def build_conversation_turn_from_cache_entry(entry: CacheEntry) -> ConversationTurn: + """Build a ConversationTurn object from a cache entry. + + Each CacheEntry represents a single conversation turn with user query, + assistant response, and optional tool calls/results. + + Args: + entry: Cache entry representing one turn in the conversation + + Returns: + ConversationTurn object with messages, tool_calls, tool_results, and timestamps + """ + # Create Message objects for user and assistant + messages = [ + Message(content=entry.query, type="user"), + Message(content=entry.response, type="assistant"), + ] + + # Extract tool calls and results (default to empty lists if None) + tool_calls = entry.tool_calls if entry.tool_calls else [] + tool_results = entry.tool_results if entry.tool_results else [] + + return ConversationTurn( + messages=messages, + tool_calls=tool_calls, + tool_results=tool_results, + provider=entry.provider, + model=entry.model, + started_at=entry.started_at, + completed_at=entry.completed_at, + ) diff --git a/src/app/endpoints/query.py b/src/app/endpoints/query.py index eddf30f86..75ad0f7e8 100644 --- a/src/app/endpoints/query.py +++ b/src/app/endpoints/query.py @@ -12,6 +12,7 @@ RateLimitError, # type: ignore ) from llama_stack_client.types.model_list_response import ModelListResponse +from sqlalchemy import func from sqlalchemy.exc import SQLAlchemyError import constants @@ -24,7 +25,7 @@ from configuration import configuration from models.cache_entry import CacheEntry from models.config import Action -from models.database.conversations import UserConversation +from models.database.conversations import UserConversation, UserTurn from models.requests import Attachment, QueryRequest from models.responses import ( ForbiddenResponse, @@ -86,7 +87,9 @@ def is_transcripts_enabled() -> bool: def persist_user_conversation_details( user_id: str, conversation_id: str, - model: str, + started_at: str, + completed_at: str, + model_id: str, provider_id: str, topic_summary: Optional[str], ) -> None: @@ -109,7 +112,7 @@ def persist_user_conversation_details( conversation = UserConversation( id=normalized_id, user_id=user_id, - last_used_model=model, + last_used_model=model_id, last_used_provider=provider_id, topic_summary=topic_summary, message_count=1, @@ -119,7 +122,7 @@ def persist_user_conversation_details( "Associated conversation %s to user %s", normalized_id, user_id ) else: - 
existing_conversation.last_used_model = model + existing_conversation.last_used_model = model_id existing_conversation.last_used_provider = provider_id existing_conversation.last_message_at = datetime.now(UTC) existing_conversation.message_count += 1 @@ -130,6 +133,34 @@ def persist_user_conversation_details( existing_conversation.message_count, ) + # Get the next turn number for this conversation + # Lock UserTurn rows for this conversation to prevent race conditions + # when computing max(turn_number) and inserting a new turn + session.query(UserTurn).filter_by( + conversation_id=normalized_id + ).with_for_update().all() + # Recompute max(turn_number) after acquiring the lock + max_turn_number = ( + session.query(func.max(UserTurn.turn_number)) + .filter_by(conversation_id=normalized_id) + .scalar() + ) + turn_number = (max_turn_number or 0) + 1 + turn = UserTurn( + conversation_id=normalized_id, + turn_number=turn_number, + started_at=datetime.fromisoformat(started_at), + completed_at=datetime.fromisoformat(completed_at), + provider=provider_id, + model=model_id, + ) + session.add(turn) + logger.debug( + "Created conversation turn - Conversation: %s, Turn: %d", + normalized_id, + turn_number, + ) + session.commit() logger.debug( "Successfully committed conversation %s to database", normalized_id @@ -313,6 +344,8 @@ async def query_endpoint_handler_base( # pylint: disable=R0914 "Topic summary generation disabled by request parameter" ) topic_summary = None + + completed_at = datetime.now(UTC).strftime("%Y-%m-%dT%H:%M:%SZ") # Convert RAG chunks to dictionary format once for reuse logger.info("Processing RAG chunks...") rag_chunks_dict = [chunk.model_dump() for chunk in summary.rag_chunks] @@ -338,12 +371,13 @@ async def query_endpoint_handler_base( # pylint: disable=R0914 persist_user_conversation_details( user_id=user_id, conversation_id=conversation_id, - model=model_id, + started_at=started_at, + completed_at=completed_at, + model_id=model_id, provider_id=provider_id, topic_summary=topic_summary, ) - completed_at = datetime.now(UTC).strftime("%Y-%m-%dT%H:%M:%SZ") cache_entry = CacheEntry( query=query_request.query, response=summary.llm_response, diff --git a/src/app/routers.py b/src/app/routers.py index b8e1d9af9..14b7c9dfb 100644 --- a/src/app/routers.py +++ b/src/app/routers.py @@ -15,7 +15,7 @@ streaming_query_v2, authorized, conversations_v2, - conversations_v3, + conversations_v1, metrics, tools, mcp_auth, @@ -54,8 +54,7 @@ def include_routers(app: FastAPI) -> None: app.include_router(streaming_query_v2.router, prefix="/v1") app.include_router(config.router, prefix="/v1") app.include_router(feedback.router, prefix="/v1") - # V1 conversations endpoint now uses V3 implementation (conversations is deprecated) - app.include_router(conversations_v3.router, prefix="/v1") + app.include_router(conversations_v1.router, prefix="/v1") app.include_router(conversations_v2.router, prefix="/v2") # Note: query_v2, streaming_query_v2, and conversations_v3 are now exposed at /v1 above diff --git a/src/models/database/conversations.py b/src/models/database/conversations.py index fd720b418..b34c9eb53 100644 --- a/src/models/database/conversations.py +++ b/src/models/database/conversations.py @@ -2,8 +2,8 @@ from datetime import datetime +from sqlalchemy import DateTime, ForeignKey, func from sqlalchemy.orm import Mapped, mapped_column -from sqlalchemy import DateTime, func from models.database.base import Base @@ -36,3 +36,33 @@ class UserConversation(Base): # pylint: disable=too-few-public-methods 
message_count: Mapped[int] = mapped_column(default=0) topic_summary: Mapped[str] = mapped_column(default="") + + +class UserTurn(Base): # pylint: disable=too-few-public-methods + """Model for storing turn-level metadata.""" + + __tablename__ = "user_turn" + + # Foreign key to user_conversation (part of composite primary key) + conversation_id: Mapped[str] = mapped_column( + ForeignKey("user_conversation.id", ondelete="CASCADE"), + primary_key=True, + ) + + # Turn number (1-indexed, first turn is 1) for ordering within a conversation + # Part of composite primary key with conversation_id + turn_number: Mapped[int] = mapped_column(primary_key=True) + + # Timestamps for the turn + started_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), + nullable=False, + ) + completed_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), + nullable=False, + ) + + provider: Mapped[str] = mapped_column(nullable=False) + + model: Mapped[str] = mapped_column(nullable=False) diff --git a/src/models/responses.py b/src/models/responses.py index 214bb47dc..7e66e27a7 100644 --- a/src/models/responses.py +++ b/src/models/responses.py @@ -2,7 +2,7 @@ """Models for REST API responses.""" -from typing import Any, ClassVar, Optional, Union +from typing import Any, ClassVar, Literal, Optional, Union from fastapi import status from pydantic import AnyUrl, BaseModel, Field @@ -835,29 +835,79 @@ class AuthorizedResponse(AbstractSuccessfulResponse): } +class Message(BaseModel): + """Model representing a message in a conversation turn. + + Attributes: + content: The message content. + type: The type of message. + """ + + content: str = Field( + ..., + description="The message content", + examples=["Hello, how can I help you?"], + ) + type: Literal["user", "assistant", "system", "developer"] = Field( + ..., + description="The type of message", + examples=["user", "assistant", "system", "developer"], + ) + + +class ConversationTurn(BaseModel): + """Model representing a single conversation turn. + + Attributes: + messages: List of messages in this turn. + tool_calls: List of tool calls made in this turn. + tool_results: List of tool results from this turn. + provider: Provider identifier used for this turn. + model: Model identifier used for this turn. + started_at: ISO 8601 timestamp when the turn started. + completed_at: ISO 8601 timestamp when the turn completed. + """ + + messages: list[Message] = Field( + default_factory=list, + description="List of messages in this turn", + ) + tool_calls: list[ToolCallSummary] = Field( + default_factory=list, + description="List of tool calls made in this turn", + ) + tool_results: list[ToolResultSummary] = Field( + default_factory=list, + description="List of tool results from this turn", + ) + provider: str = Field( + ..., + description="Provider identifier used for this turn", + examples=["openai"], + ) + model: str = Field( + ..., + description="Model identifier used for this turn", + examples=["gpt-4o-mini"], + ) + started_at: str = Field( + ..., + description="ISO 8601 timestamp when the turn started", + examples=["2024-01-01T00:01:00Z"], + ) + completed_at: str = Field( + ..., + description="ISO 8601 timestamp when the turn completed", + examples=["2024-01-01T00:01:05Z"], + ) + + class ConversationResponse(AbstractSuccessfulResponse): """Model representing a response for retrieving a conversation. Attributes: conversation_id: The conversation ID (UUID). - chat_history: The simplified chat history as a list of conversation turns. 
- - Example: - ```python - conversation_response = ConversationResponse( - conversation_id="123e4567-e89b-12d3-a456-426614174000", - chat_history=[ - { - "messages": [ - {"content": "Hello", "type": "user"}, - {"content": "Hi there!", "type": "assistant"} - ], - "started_at": "2024-01-01T00:01:00Z", - "completed_at": "2024-01-01T00:01:05Z" - } - ] - ) - ``` + chat_history: The chat history as a list of conversation turns. """ conversation_id: str = Field( @@ -866,7 +916,7 @@ class ConversationResponse(AbstractSuccessfulResponse): examples=["c5260aec-4d82-4370-9fdf-05cf908b3f16"], ) - chat_history: list[dict[str, Any]] = Field( + chat_history: list[ConversationTurn] = Field( ..., description="The simplified chat history as a list of conversation turns", examples=[ @@ -875,6 +925,10 @@ class ConversationResponse(AbstractSuccessfulResponse): {"content": "Hello", "type": "user"}, {"content": "Hi there!", "type": "assistant"}, ], + "tool_calls": [], + "tool_results": [], + "provider": "openai", + "model": "gpt-4o-mini", "started_at": "2024-01-01T00:01:00Z", "completed_at": "2024-01-01T00:01:05Z", } @@ -893,6 +947,10 @@ class ConversationResponse(AbstractSuccessfulResponse): {"content": "Hello", "type": "user"}, {"content": "Hi there!", "type": "assistant"}, ], + "tool_calls": [], + "tool_results": [], + "provider": "openai", + "model": "gpt-4o-mini", "started_at": "2024-01-01T00:01:00Z", "completed_at": "2024-01-01T00:01:05Z", } diff --git a/src/utils/conversations.py b/src/utils/conversations.py new file mode 100644 index 000000000..83c485a3e --- /dev/null +++ b/src/utils/conversations.py @@ -0,0 +1,382 @@ +"""Utilities for conversations.""" + +import json +from datetime import UTC, datetime +from typing import Any, Optional, Union, cast + +from llama_stack_api.openai_responses import ( + OpenAIResponseOutputMessageFileSearchToolCall, + OpenAIResponseOutputMessageFunctionToolCall, + OpenAIResponseOutputMessageMCPCall, + OpenAIResponseOutputMessageMCPListTools, + OpenAIResponseOutputMessageWebSearchToolCall, +) +from llama_stack_client.types.conversations.item_list_response import ( + ItemListResponse, + OpenAIResponseInputFunctionToolCallOutput, + OpenAIResponseMcpApprovalRequest, + OpenAIResponseMcpApprovalResponse, + OpenAIResponseMessageOutput, +) + +from constants import DEFAULT_RAG_TOOL +from models.database.conversations import UserTurn +from models.responses import ConversationTurn, Message +from utils.query import parse_arguments_string +from utils.types import ToolCallSummary, ToolResultSummary + + +def _extract_text_from_content(content: Union[str, list[Any]]) -> str: + """Extract text content from message content. + + Args: + content: The content field from a message (can be str or list) + + Returns: + Extracted text content as a string + """ + if isinstance(content, str): + return content + + text_fragments: list[str] = [] + if isinstance(content, list): + for part in content: + if isinstance(part, str): + text_fragments.append(part) + continue + text_value = getattr(part, "text", None) + if text_value: + text_fragments.append(text_value) + continue + refusal = getattr(part, "refusal", None) + if refusal: + text_fragments.append(refusal) + continue + if isinstance(part, dict): + dict_text = part.get("text") or part.get("refusal") + if dict_text: + text_fragments.append(str(dict_text)) + + return "".join(text_fragments) + + +def _parse_message_item(item: OpenAIResponseMessageOutput) -> Message: + """Parse a message item into a Message object. 
+
+    Args:
+        item: The message item from Conversations API
+
+    Returns:
+        Message object with extracted content and type (user or assistant)
+    """
+    content_text = _extract_text_from_content(item.content)
+    message_type = item.role
+    return Message(content=content_text, type=message_type)
+
+
+def _build_tool_call_summary_from_item(  # pylint: disable=too-many-return-statements
+    item: ItemListResponse,
+) -> tuple[Optional[ToolCallSummary], Optional[ToolResultSummary]]:
+    """Translate Conversations API tool items into ToolCallSummary and ToolResultSummary records.
+
+    Args:
+        item: A tool item from the Conversations API items list
+
+    Returns:
+        A tuple of (ToolCallSummary, ToolResultSummary); either element may be None
+        if the item type doesn't provide both call and result information.
+    """
+    item_type = getattr(item, "type", None)
+
+    if item_type == "function_call":
+        function_call_item = cast(OpenAIResponseOutputMessageFunctionToolCall, item)
+        return (
+            ToolCallSummary(
+                id=function_call_item.call_id,
+                name=function_call_item.name,
+                args=parse_arguments_string(function_call_item.arguments),
+                type="function_call",
+            ),
+            None,  # Function call results come as separate function_call_output items
+        )
+
+    if item_type == "file_search_call":
+        file_search_item = cast(OpenAIResponseOutputMessageFileSearchToolCall, item)
+        response_payload: Optional[dict[str, Any]] = None
+        if file_search_item.results is not None:
+            response_payload = {
+                "results": [result.model_dump() for result in file_search_item.results]
+            }
+        return (
+            ToolCallSummary(
+                id=file_search_item.id,
+                name=DEFAULT_RAG_TOOL,
+                args={"queries": file_search_item.queries},
+                type="file_search_call",
+            ),
+            ToolResultSummary(
+                id=file_search_item.id,
+                status=file_search_item.status,
+                content=json.dumps(response_payload) if response_payload else "",
+                type="file_search_call",
+                round=1,
+            ),
+        )
+
+    if item_type == "web_search_call":
+        web_search_item = cast(OpenAIResponseOutputMessageWebSearchToolCall, item)
+        return (
+            ToolCallSummary(
+                id=web_search_item.id,
+                name="web_search",
+                args={},
+                type="web_search_call",
+            ),
+            ToolResultSummary(
+                id=web_search_item.id,
+                status=web_search_item.status,
+                content="",
+                type="web_search_call",
+                round=1,
+            ),
+        )
+
+    if item_type == "mcp_call":
+        mcp_call_item = cast(OpenAIResponseOutputMessageMCPCall, item)
+        args = parse_arguments_string(mcp_call_item.arguments)
+        if mcp_call_item.server_label:
+            args["server_label"] = mcp_call_item.server_label
+        content = (
+            mcp_call_item.error
+            if mcp_call_item.error
+            else (mcp_call_item.output if mcp_call_item.output else "")
+        )
+
+        return (
+            ToolCallSummary(
+                id=mcp_call_item.id,
+                name=mcp_call_item.name,
+                args=args,
+                type="mcp_call",
+            ),
+            ToolResultSummary(
+                id=mcp_call_item.id,
+                status="success" if mcp_call_item.error is None else "failure",
+                content=content,
+                type="mcp_call",
+                round=1,
+            ),
+        )
+
+    if item_type == "mcp_list_tools":
+        mcp_list_tools_item = cast(OpenAIResponseOutputMessageMCPListTools, item)
+        tools_info = [
+            {
+                "name": tool.name,
+                "description": tool.description,
+                "input_schema": tool.input_schema,
+            }
+            for tool in mcp_list_tools_item.tools
+        ]
+        content_dict = {
+            "server_label": mcp_list_tools_item.server_label,
+            "tools": tools_info,
+        }
+        return (
+            ToolCallSummary(
+                id=mcp_list_tools_item.id,
+                name="mcp_list_tools",
+                args={"server_label": mcp_list_tools_item.server_label},
+                type="mcp_list_tools",
+            ),
+            ToolResultSummary(
+                id=mcp_list_tools_item.id,
+                status="success",
+                content=json.dumps(content_dict),
+                type="mcp_list_tools",
+                round=1,
+            ),
+        )
+
+    if item_type == "mcp_approval_request":
+        approval_request_item = cast(OpenAIResponseMcpApprovalRequest, item)
+        args = parse_arguments_string(approval_request_item.arguments)
+        return (
+            ToolCallSummary(
+                id=approval_request_item.id,
+                name=approval_request_item.name,
+                args=args,
+                type="tool_call",
+            ),
+            None,
+        )
+
+    if item_type == "mcp_approval_response":
+        approval_response_item = cast(OpenAIResponseMcpApprovalResponse, item)
+        content_dict = {}
+        if approval_response_item.reason:
+            content_dict["reason"] = approval_response_item.reason
+        return (
+            None,
+            ToolResultSummary(
+                id=approval_response_item.approval_request_id,
+                status="success" if approval_response_item.approve else "denied",
+                content=json.dumps(content_dict),
+                type="mcp_approval_response",
+                round=1,
+            ),
+        )
+
+    if item_type == "function_call_output":
+        function_output = cast(OpenAIResponseInputFunctionToolCallOutput, item)
+        return (
+            None,
+            ToolResultSummary(
+                id=function_output.call_id,
+                status=function_output.status or "success",
+                content=function_output.output,
+                type="function_call_output",
+                round=1,
+            ),
+        )
+
+    return None, None
+
+
+def _create_dummy_turn_metadata(started_at: datetime) -> UserTurn:
+    """Create a dummy UserTurn instance for legacy conversations without metadata.
+
+    Args:
+        started_at: Timestamp to use for started_at and completed_at (conversation created_at)
+
+    Returns:
+        UserTurn instance with default values (N/A for provider/model, provided timestamp)
+        for legacy conversations that don't have stored turn metadata.
+    """
+    # Create a UserTurn instance with default values for legacy conversations
+    # Note: conversation_id and turn_number are not used, so we use placeholder values
+    return UserTurn(
+        conversation_id="",
+        turn_number=0,
+        started_at=started_at,
+        completed_at=started_at,
+        provider="N/A",
+        model="N/A",
+    )
+
+
+def _create_turn_from_db_metadata(
+    turn_metadata: UserTurn,
+    messages: list[Message],
+    tool_calls: list[ToolCallSummary],
+    tool_results: list[ToolResultSummary],
+) -> ConversationTurn:
+    """Create a ConversationTurn from database metadata and accumulated items.
+
+    Args:
+        turn_metadata: Database UserTurn object with metadata
+        messages: List of messages for this turn
+        tool_calls: List of tool calls for this turn
+        tool_results: List of tool results for this turn
+
+    Returns:
+        ConversationTurn object with all metadata populated
+    """
+    started_at = turn_metadata.started_at.astimezone(UTC).strftime("%Y-%m-%dT%H:%M:%SZ")
+    completed_at = turn_metadata.completed_at.astimezone(UTC).strftime(
+        "%Y-%m-%dT%H:%M:%SZ"
+    )
+    return ConversationTurn(
+        messages=messages,
+        tool_calls=tool_calls,
+        tool_results=tool_results,
+        provider=turn_metadata.provider,
+        model=turn_metadata.model,
+        started_at=started_at,
+        completed_at=completed_at,
+    )
+
+
+def build_conversation_turns_from_items(
+    items: list[ItemListResponse],
+    turns_metadata: list[UserTurn],
+    conversation_start_time: datetime,
+) -> list[ConversationTurn]:
+    """Build conversation turns from Conversations API items and turns metadata.
+
+    Args:
+        items: Conversation items list from Conversations API, oldest first
+        turns_metadata: List of UserTurn database objects ordered by turn_number.
+            Can be empty for legacy conversations without stored metadata.
+        conversation_start_time: Timestamp to use for dummy metadata in legacy conversations.
+            Typically the conversation's created_at timestamp.
+ + Returns: + List of ConversationTurn objects, oldest first + """ + chat_history: list[ConversationTurn] = [] + current_messages: list[Message] = [] + current_tool_calls: list[ToolCallSummary] = [] + current_tool_results: list[ToolResultSummary] = [] + current_turn_index = 0 + + for item in items: + item_type = getattr(item, "type", None) + + # Parse message items + if item_type == "message": + message_item = cast(OpenAIResponseMessageOutput, item) + message = _parse_message_item(message_item) + + # User message marks the beginning of a new turn + if message.type == "user": + # If we have accumulated items, finish the previous turn + if current_messages or current_tool_calls or current_tool_results: + turn_metadata = ( + turns_metadata[current_turn_index] + if current_turn_index < len(turns_metadata) + else _create_dummy_turn_metadata(conversation_start_time) + ) + chat_history.append( + _create_turn_from_db_metadata( + turn_metadata, + current_messages, + current_tool_calls, + current_tool_results, + ) + ) + current_turn_index += 1 + + # Start new turn with this user message + current_messages = [message] + current_tool_calls = [] + current_tool_results = [] + else: + # Add non-user message to current turn + current_messages.append(message) + + # Parse tool-related items + else: + tool_call, tool_result = _build_tool_call_summary_from_item(item) + if tool_call is not None: + current_tool_calls.append(tool_call) + if tool_result is not None: + current_tool_results.append(tool_result) + + # Add final turn if there are items + if current_messages or current_tool_calls or current_tool_results: + turn_metadata = ( + turns_metadata[current_turn_index] + if current_turn_index < len(turns_metadata) + else _create_dummy_turn_metadata(conversation_start_time) + ) + chat_history.append( + _create_turn_from_db_metadata( + turn_metadata, + current_messages, + current_tool_calls, + current_tool_results, + ) + ) + + return chat_history diff --git a/src/utils/endpoints.py b/src/utils/endpoints.py index b0b49917d..4c555fab3 100644 --- a/src/utils/endpoints.py +++ b/src/utils/endpoints.py @@ -822,7 +822,9 @@ async def cleanup_after_streaming( persist_user_conversation_details_func( user_id=user_id, conversation_id=conversation_id, - model=model_id, + started_at=started_at, + completed_at=completed_at, + model_id=model_id, provider_id=provider_id, topic_summary=topic_summary, ) diff --git a/tests/e2e/features/conversation_cache_v2.feature b/tests/e2e/features/conversation_cache_v2.feature index efc0ba601..51e2be687 100644 --- a/tests/e2e/features/conversation_cache_v2.feature +++ b/tests/e2e/features/conversation_cache_v2.feature @@ -130,10 +130,18 @@ Feature: Conversation Cache V2 API tests } } }, + "tool_calls": { + "type": "array", + "items": { "type": "object" } + }, + "tool_results": { + "type": "array", + "items": { "type": "object" } + }, "started_at": { "type": "string", "format": "date-time" }, "completed_at": { "type": "string", "format": "date-time" } }, - "required": ["provider", "model", "messages", "started_at", "completed_at"] + "required": ["provider", "model", "messages", "tool_calls", "tool_results", "started_at", "completed_at"] } } } diff --git a/tests/e2e/features/conversations.feature b/tests/e2e/features/conversations.feature index a3f04078b..1d7671f29 100644 --- a/tests/e2e/features/conversations.feature +++ b/tests/e2e/features/conversations.feature @@ -73,6 +73,8 @@ Feature: conversations endpoint API tests "items": { "type": "object", "properties": { + "provider": { "type": 
"string" }, + "model": { "type": "string" }, "messages": { "type": "array", "items": { @@ -83,9 +85,18 @@ Feature: conversations endpoint API tests } } }, + "tool_calls": { + "type": "array", + "items": { "type": "object" } + }, + "tool_results": { + "type": "array", + "items": { "type": "object" } + }, "started_at": { "type": "string", "format": "date-time" }, "completed_at": { "type": "string", "format": "date-time" } - } + }, + "required": ["provider", "model", "messages", "tool_calls", "tool_results", "started_at", "completed_at"] } } } diff --git a/tests/integration/test_openapi_json.py b/tests/integration/test_openapi_json.py index 0e098f940..53dbacc02 100644 --- a/tests/integration/test_openapi_json.py +++ b/tests/integration/test_openapi_json.py @@ -227,7 +227,7 @@ def test_servers_section_present_from_url(spec_from_url: dict[str, Any]) -> None ("/v1/feedback", "post", {"200", "401", "403", "404", "500"}), ("/v1/feedback/status", "get", {"200"}), ("/v1/feedback/status", "put", {"200", "401", "403", "500"}), - ("/v1/conversations", "get", {"200", "401", "403", "500", "503"}), + ("/v1/conversations", "get", {"200", "401", "403", "500"}), ( "/v1/conversations/{conversation_id}", "get", @@ -309,7 +309,7 @@ def test_paths_and_responses_exist_from_file( ("/v1/feedback", "post", {"200", "401", "403", "404", "500"}), ("/v1/feedback/status", "get", {"200"}), ("/v1/feedback/status", "put", {"200", "401", "403", "500"}), - ("/v1/conversations", "get", {"200", "401", "403", "500", "503"}), + ("/v1/conversations", "get", {"200", "401", "403", "500"}), ( "/v1/conversations/{conversation_id}", "get", diff --git a/tests/unit/app/endpoints/test_conversations.py b/tests/unit/app/endpoints/test_conversations.py index 7e12da21a..b8ec80d45 100644 --- a/tests/unit/app/endpoints/test_conversations.py +++ b/tests/unit/app/endpoints/test_conversations.py @@ -3,27 +3,31 @@ """Unit tests for the /conversations REST API endpoints.""" +from datetime import UTC, datetime from typing import Any, Optional import pytest from fastapi import HTTPException, Request, status -from llama_stack_client import APIConnectionError, NotFoundError +from llama_stack_client import APIConnectionError, APIStatusError, NotFoundError from pytest_mock import MockerFixture, MockType from sqlalchemy.exc import SQLAlchemyError -from app.endpoints.conversations import ( +from app.endpoints.conversations_v1 import ( delete_conversation_endpoint_handler, get_conversation_endpoint_handler, get_conversations_list_endpoint_handler, - simplify_session_data, + update_conversation_endpoint_handler, ) +from utils.conversations import build_conversation_turns_from_items from configuration import AppConfig from models.config import Action -from models.database.conversations import UserConversation +from models.database.conversations import UserConversation, UserTurn +from models.requests import ConversationUpdateRequest from models.responses import ( ConversationDeleteResponse, ConversationResponse, ConversationsListResponse, + ConversationUpdateResponse, ) from tests.unit.utils.auth_helpers import mock_authorization_resolvers @@ -104,8 +108,41 @@ def create_mock_conversation( return mock_conversation +def create_mock_db_turn( + mocker: MockerFixture, + turn_number: int, + started_at: str = "2024-01-01T00:01:00Z", + completed_at: str = "2024-01-01T00:01:05Z", + provider: str = "google", + model: str = "gemini-2.0-flash-exp", +) -> MockType: + """Create a mock UserTurn database object. 
+
+    Args:
+        mocker: Mocker fixture
+        turn_number: Turn number (1-indexed)
+        started_at: ISO 8601 timestamp string
+        completed_at: ISO 8601 timestamp string
+        provider: Provider identifier
+        model: Model identifier
+
+    Returns:
+        Mock UserTurn database object with required attributes
+    """
+    mock_turn = mocker.Mock(spec=UserTurn)
+    mock_turn.turn_number = turn_number
+    # Convert ISO strings to datetime objects (Python 3.11+ supports "Z" directly)
+    mock_turn.started_at = datetime.fromisoformat(started_at)
+    mock_turn.completed_at = datetime.fromisoformat(completed_at)
+    mock_turn.provider = provider
+    mock_turn.model = model
+    return mock_turn
+
+
 def mock_database_session(
-    mocker: MockerFixture, query_result: Optional[list[MockType]] = None
+    mocker: MockerFixture,
+    query_result: Optional[list[MockType]] = None,
+    db_turns: Optional[list[MockType]] = None,
 ) -> MockType:
     """Helper function to mock get_session with proper context manager support.
@@ -115,26 +152,47 @@ def mock_database_session(
         mocker (pytest.MockerFixture): Fixture used to create and patch mocks.
         query_result (Optional[list]): If provided, configures the
             session.query().all() and session.query().filter_by().all() to return
-            this list.
+            this list (for UserConversation queries).
+        db_turns (Optional[list]): If provided, configures UserTurn queries
+            to return this list.
 
     Returns:
         Mock: The mocked session object that will be yielded by the patched
             get_session context manager.
     """
     mock_session = mocker.Mock()
-    if query_result is not None:
-        # Mock both the filtered and unfiltered query paths
+
+    def query_side_effect(model_class: type[Any]) -> Any:
+        """Handle different model queries."""
         mock_query = mocker.Mock()
-        mock_query.all.return_value = query_result
-        mock_query.filter_by.return_value.all.return_value = query_result
-        mock_session.query.return_value = mock_query
+        if model_class == UserTurn:
+            # For UserTurn queries
+            if db_turns is not None:
+                mock_query.filter_by.return_value.order_by.return_value.all.return_value = (
+                    db_turns
+                )
+            else:
+                mock_query.filter_by.return_value.order_by.return_value.all.return_value = (
+                    []
+                )
+        else:
+            # For other queries (UserConversation, etc.)
+            if query_result is not None:
+                mock_query.all.return_value = query_result
+                mock_query.filter_by.return_value.all.return_value = query_result
+                mock_query.filter_by.return_value.first.return_value = (
+                    query_result[0] if query_result else None
+                )
+        return mock_query
+
+    mock_session.query.side_effect = query_side_effect
 
     # Mock get_session to return a context manager
     mock_session_context = mocker.MagicMock()
     mock_session_context.__enter__.return_value = mock_session
     mock_session_context.__exit__.return_value = None
     mocker.patch(
-        "app.endpoints.conversations.get_session", return_value=mock_session_context
+        "app.endpoints.conversations_v1.get_session", return_value=mock_session_context
     )
     return mock_session
 
 
@@ -251,6 +309,8 @@ def expected_chat_history_fixture() -> list[dict[str, Any]]:
         list[dict[str, Any]]: A list of conversation turns.
Each turn contains: - messages: list of message dicts with `content` (str) and `type` (`"user"` or `"assistant"`) + - tool_calls: list of tool call summaries (empty by default) + - tool_results: list of tool result summaries (empty by default) + - provider: provider identifier for the turn + - model: model identifier for the turn - started_at: ISO 8601 UTC timestamp string for the turn start - completed_at: ISO 8601 UTC timestamp string for the turn end """ @@ -260,6 +320,10 @@ def expected_chat_history_fixture() -> list[dict[str, Any]]: {"content": "Hello", "type": "user"}, {"content": "Hi there!", "type": "assistant"}, ], + "tool_calls": [], + "tool_results": [], + "provider": "google", + "model": "gemini-2.0-flash-exp", "started_at": "2024-01-01T00:01:00Z", "completed_at": "2024-01-01T00:01:05Z", }, @@ -268,6 +332,10 @@ def expected_chat_history_fixture() -> list[dict[str, Any]]: {"content": "How are you?", "type": "user"}, {"content": "I'm doing well, thanks!", "type": "assistant"}, ], + "tool_calls": [], + "tool_results": [], + "provider": "google", + "model": "gemini-2.0-flash-exp", "started_at": "2024-01-01T00:02:00Z", "completed_at": "2024-01-01T00:02:03Z", }, @@ -294,77 +362,54 @@ def mock_conversation_fixture() -> UserConversation: return mock_conv -class TestSimplifySessionData: - """Test cases for the simplify_session_data function.""" +class TestBuildConversationTurnsFromItems: + """Test cases for the build_conversation_turns_from_items function.""" @pytest.mark.asyncio - async def test_simplify_session_data_with_model_dump( + async def test_build_conversation_turns_from_items_with_model_dump( self, - mock_session_data: dict[str, Any], + mocker: MockerFixture, + mock_session_data: dict[str, Any], # pylint: disable=unused-argument expected_chat_history: list[dict[str, Any]], ) -> None: - """Test simplify_session_data with session data.""" - result = simplify_session_data(mock_session_data) - - assert result == expected_chat_history + """Test build_conversation_turns_from_items with items data.""" + # Create mock items from session_data structure + mock_items = [ + mocker.Mock(type="message", role="user", content="Hello"), + mocker.Mock(type="message", role="assistant", content="Hi there!"), + mocker.Mock(type="message", role="user", content="How are you?"), + mocker.Mock( + type="message", role="assistant", content="I'm doing well, thanks!"
+ ), + ] + # Create mock db_turns matching the expected turns + mock_db_turns = [ + create_mock_db_turn( + mocker, 1, "2024-01-01T00:01:00Z", "2024-01-01T00:01:05Z" + ), + create_mock_db_turn( + mocker, 2, "2024-01-01T00:02:00Z", "2024-01-01T00:02:03Z" + ), + ] + conversation_start_time = datetime.fromisoformat( + "2024-01-01T00:00:00Z" + ).replace(tzinfo=UTC) + result = build_conversation_turns_from_items( + mock_items, mock_db_turns, conversation_start_time + ) + actual_history = [turn.model_dump() for turn in result] + assert actual_history == expected_chat_history @pytest.mark.asyncio - async def test_simplify_session_data_empty_turns(self) -> None: - """Test simplify_session_data with empty turns.""" - session_data = { - "session_id": VALID_CONVERSATION_ID, - "started_at": "2024-01-01T00:00:00Z", - "turns": [], - } - - result = simplify_session_data(session_data) + async def test_build_conversation_turns_from_items_empty_turns(self) -> None: + """Test build_conversation_turns_from_items with empty items.""" + conversation_start_time = datetime.fromisoformat( + "2024-01-01T00:00:00Z" + ).replace(tzinfo=UTC) + result = build_conversation_turns_from_items([], [], conversation_start_time) assert not result - @pytest.mark.asyncio - async def test_simplify_session_data_filters_unwanted_fields(self) -> None: - """Test that simplify_session_data properly filters out unwanted fields.""" - session_data = { - "session_id": VALID_CONVERSATION_ID, - "turns": [ - { - "turn_id": "turn-1", - "input_messages": [ - { - "content": "Test message", - "role": "user", - "context": {"some": "context"}, # Should be filtered out - "metadata": {"extra": "data"}, # Should be filtered out - } - ], - "output_message": { - "content": "Test response", - "role": "assistant", - "stop_reason": "end_of_turn", # Should be filtered out - "tool_calls": ["tool1", "tool2"], # Should be filtered out - }, - "started_at": "2024-01-01T00:01:00Z", - "completed_at": "2024-01-01T00:01:05Z", - "steps": ["step1", "step2"], # Should be filtered out - } - ], - } - - result = simplify_session_data(session_data) - - expected = [ - { - "messages": [ - {"content": "Test message", "type": "user"}, - {"content": "Test response", "type": "assistant"}, - ], - "started_at": "2024-01-01T00:01:00Z", - "completed_at": "2024-01-01T00:01:05Z", - } - ] - - assert result == expected - class TestGetConversationEndpoint: """Test cases for the GET /conversations/{conversation_id} endpoint.""" @@ -376,7 +421,7 @@ async def test_configuration_not_loaded( """Test the endpoint when configuration is not loaded.""" mock_authorization_resolvers(mocker) mock_config = AppConfig() - mocker.patch("app.endpoints.conversations.configuration", mock_config) + mocker.patch("app.endpoints.conversations_v1.configuration", mock_config) with pytest.raises(HTTPException) as exc_info: await get_conversation_endpoint_handler( @@ -400,8 +445,10 @@ async def test_invalid_conversation_id_format( ) -> None: """Test the endpoint with an invalid conversation ID format.""" mock_authorization_resolvers(mocker) - mocker.patch("app.endpoints.conversations.configuration", setup_configuration) - mocker.patch("app.endpoints.conversations.check_suid", return_value=False) + mocker.patch( + "app.endpoints.conversations_v1.configuration", setup_configuration + ) + mocker.patch("app.endpoints.conversations_v1.check_suid", return_value=False) with pytest.raises(HTTPException) as exc_info: await get_conversation_endpoint_handler( @@ -425,18 +472,23 @@ async def test_llama_stack_connection_error( 
) -> None: """Test the endpoint when LlamaStack connection fails.""" mock_authorization_resolvers(mocker) - mocker.patch("app.endpoints.conversations.configuration", setup_configuration) - mocker.patch("app.endpoints.conversations.check_suid", return_value=True) - mocker.patch("app.endpoints.conversations.can_access_conversation") - mocker.patch("app.endpoints.conversations.retrieve_conversation") + mocker.patch( + "app.endpoints.conversations_v1.configuration", setup_configuration + ) + mocker.patch("app.endpoints.conversations_v1.check_suid", return_value=True) + mocker.patch("app.endpoints.conversations_v1.can_access_conversation") + mocker.patch("app.endpoints.conversations_v1.retrieve_conversation") + + # Mock database session (empty db_turns since API call fails before query) + mock_database_session(mocker, db_turns=[]) # Mock AsyncLlamaStackClientHolder to raise APIConnectionError mock_client = mocker.AsyncMock() - mock_client.agents.session.list.side_effect = APIConnectionError( + mock_client.conversations.items.list.side_effect = APIConnectionError( request=None # type: ignore ) mock_client_holder = mocker.patch( - "app.endpoints.conversations.AsyncLlamaStackClientHolder" + "app.endpoints.conversations_v1.AsyncLlamaStackClientHolder" ) mock_client_holder.return_value.get_client.return_value = mock_client @@ -471,18 +523,25 @@ async def test_llama_stack_not_found_error( "does not exist" and the conversation ID. """ mock_authorization_resolvers(mocker) - mocker.patch("app.endpoints.conversations.configuration", setup_configuration) - mocker.patch("app.endpoints.conversations.check_suid", return_value=True) - mocker.patch("app.endpoints.conversations.can_access_conversation") - mocker.patch("app.endpoints.conversations.retrieve_conversation") + mocker.patch( + "app.endpoints.conversations_v1.configuration", setup_configuration + ) + mocker.patch("app.endpoints.conversations_v1.check_suid", return_value=True) + mocker.patch("app.endpoints.conversations_v1.can_access_conversation") + mocker.patch("app.endpoints.conversations_v1.retrieve_conversation") + + # Mock database session (empty db_turns since API call fails before query) + mock_database_session(mocker, db_turns=[]) # Mock AsyncLlamaStackClientHolder to raise NotFoundError mock_client = mocker.AsyncMock() - mock_client.agents.session.list.side_effect = NotFoundError( - message="Session not found", response=mocker.Mock(request=None), body=None + mock_client.conversations.items.list.side_effect = NotFoundError( + message="Conversation not found", + response=mocker.Mock(request=None), + body=None, ) mock_client_holder = mocker.patch( - "app.endpoints.conversations.AsyncLlamaStackClientHolder" + "app.endpoints.conversations_v1.AsyncLlamaStackClientHolder" ) mock_client_holder.return_value.get_client.return_value = mock_client @@ -501,42 +560,6 @@ async def test_llama_stack_not_found_error( assert "does not exist" in detail["cause"] # type: ignore assert VALID_CONVERSATION_ID in detail["cause"] # type: ignore - @pytest.mark.asyncio - async def test_session_retrieve_exception( - self, - mocker: MockerFixture, - setup_configuration: AppConfig, - dummy_request: Request, - ) -> None: - """Test the endpoint when session retrieval raises an APIConnectionError.""" - mock_authorization_resolvers(mocker) - mocker.patch("app.endpoints.conversations.configuration", setup_configuration) - mocker.patch("app.endpoints.conversations.check_suid", return_value=True) - mocker.patch("app.endpoints.conversations.can_access_conversation") - 
mocker.patch("app.endpoints.conversations.retrieve_conversation") - - # Mock AsyncLlamaStackClientHolder to raise APIConnectionError - mock_client_holder = mocker.patch( - "app.endpoints.conversations.AsyncLlamaStackClientHolder" - ) - mock_client = mocker.AsyncMock() - mock_client.agents.session.list.side_effect = APIConnectionError( - request=mocker.Mock() - ) - mock_client_holder.return_value.get_client.return_value = mock_client - - with pytest.raises(HTTPException) as exc_info: - await get_conversation_endpoint_handler( - request=dummy_request, - conversation_id=VALID_CONVERSATION_ID, - auth=MOCK_AUTH, - ) - - assert exc_info.value.status_code == status.HTTP_503_SERVICE_UNAVAILABLE - detail = exc_info.value.detail - assert isinstance(detail, dict) - assert detail["response"] == "Unable to connect to Llama Stack" # type: ignore - @pytest.mark.asyncio async def test_get_conversation_forbidden( self, @@ -546,10 +569,12 @@ async def test_get_conversation_forbidden( mock_conversation: MockType, ) -> None: """Test forbidden access when user lacks permission to read conversation.""" - mocker.patch("app.endpoints.conversations.configuration", setup_configuration) - mocker.patch("app.endpoints.conversations.check_suid", return_value=True) mocker.patch( - "app.endpoints.conversations.retrieve_conversation", + "app.endpoints.conversations_v1.configuration", setup_configuration + ) + mocker.patch("app.endpoints.conversations_v1.check_suid", return_value=True) + mocker.patch( + "app.endpoints.conversations_v1.retrieve_conversation", return_value=mock_conversation, ) mocker.patch( @@ -594,31 +619,46 @@ async def test_get_others_conversations_allowed_for_authorized_user( setup_configuration: AppConfig, mock_conversation: MockType, dummy_request: Request, - mock_session_data: dict[str, Any], - ) -> None: # pylint: disable=too-many-arguments, too-many-positional-arguments + ) -> None: """Test allowed access to another user's conversation for authorized user.""" mocker.patch( "authorization.resolvers.NoopAccessResolver.get_actions", return_value={Action.GET_CONVERSATION, Action.READ_OTHERS_CONVERSATIONS}, ) # Allow user to access other users' conversations - mocker.patch("app.endpoints.conversations.configuration", setup_configuration) - mocker.patch("app.endpoints.conversations.check_suid", return_value=True) mocker.patch( - "app.endpoints.conversations.retrieve_conversation", + "app.endpoints.conversations_v1.configuration", setup_configuration + ) + mocker.patch("app.endpoints.conversations_v1.check_suid", return_value=True) + mocker.patch( + "app.endpoints.conversations_v1.retrieve_conversation", return_value=mock_conversation, ) - mock_client = mocker.AsyncMock() - mock_client.agents.session.list.return_value = mocker.Mock( - data=[mock_session_data] - ) + # Mock UserTurn database queries + mock_db_turns = [ + create_mock_db_turn( + mocker, 1, "2024-01-01T00:01:00Z", "2024-01-01T00:01:05Z" + ), + ] + mock_database_session(mocker, db_turns=mock_db_turns) - mock_session_retrieve_result = mocker.Mock() - mock_session_retrieve_result.model_dump.return_value = mock_session_data - mock_client.agents.session.retrieve.return_value = mock_session_retrieve_result + # Mock Conversations API - conversations_v1 uses conversations.items.list + mock_client = mocker.AsyncMock() + mock_items_response = mocker.Mock() + # Create mock items that match mock_session_data structure + mock_item1 = mocker.Mock() + mock_item1.type = "message" + mock_item1.role = "user" + mock_item1.content = "Hello" + mock_item2 = 
mocker.Mock() + mock_item2.type = "message" + mock_item2.role = "assistant" + mock_item2.content = "Hi there!" + mock_items_response.data = [mock_item1, mock_item2] + mock_client.conversations.items.list.return_value = mock_items_response mock_client_holder = mocker.patch( - "app.endpoints.conversations.AsyncLlamaStackClientHolder" + "app.endpoints.conversations_v1.AsyncLlamaStackClientHolder" ) mock_client_holder.return_value.get_client.return_value = mock_client response = await get_conversation_endpoint_handler( @@ -635,30 +675,45 @@ async def test_successful_conversation_retrieval( self, mocker: MockerFixture, setup_configuration: AppConfig, - mock_session_data: dict[str, Any], expected_chat_history: list[dict[str, Any]], dummy_request: Request, - ) -> None: # pylint: disable=too-many-arguments,too-many-positional-arguments + ) -> None: """Test successful conversation retrieval with simplified response structure.""" mock_authorization_resolvers(mocker) - mocker.patch("app.endpoints.conversations.configuration", setup_configuration) - mocker.patch("app.endpoints.conversations.check_suid", return_value=True) - mocker.patch("app.endpoints.conversations.can_access_conversation") - mocker.patch("app.endpoints.conversations.retrieve_conversation") - - # Mock AsyncLlamaStackClientHolder - mock_client = mocker.AsyncMock() - mock_client.agents.session.list.return_value = mocker.Mock( - data=[mock_session_data] + mocker.patch( + "app.endpoints.conversations_v1.configuration", setup_configuration ) + mocker.patch("app.endpoints.conversations_v1.check_suid", return_value=True) + mocker.patch("app.endpoints.conversations_v1.can_access_conversation") + mocker.patch("app.endpoints.conversations_v1.retrieve_conversation") + + # Mock UserTurn database queries - create mock db_turns matching expected turns + mock_db_turns = [ + create_mock_db_turn( + mocker, 1, "2024-01-01T00:01:00Z", "2024-01-01T00:01:05Z" + ), + create_mock_db_turn( + mocker, 2, "2024-01-01T00:02:00Z", "2024-01-01T00:02:03Z" + ), + ] + mock_database_session(mocker, db_turns=mock_db_turns) - # Mock session.retrieve to return an object with model_dump() method - mock_session_retrieve_result = mocker.Mock() - mock_session_retrieve_result.model_dump.return_value = mock_session_data - mock_client.agents.session.retrieve.return_value = mock_session_retrieve_result + # Mock AsyncLlamaStackClientHolder - conversations_v1 uses conversations.items.list + mock_client = mocker.AsyncMock() + # Create mock items that will produce 2 turns (user + assistant messages) + mock_items = mocker.Mock() + mock_items.data = [ + mocker.Mock(type="message", role="user", content="Hello"), + mocker.Mock(type="message", role="assistant", content="Hi there!"), + mocker.Mock(type="message", role="user", content="How are you?"), + mocker.Mock( + type="message", role="assistant", content="I'm doing well, thanks!" 
+ ), + ] + mock_client.conversations.items.list.return_value = mock_items mock_client_holder = mocker.patch( - "app.endpoints.conversations.AsyncLlamaStackClientHolder" + "app.endpoints.conversations_v1.AsyncLlamaStackClientHolder" ) mock_client_holder.return_value.get_client.return_value = mock_client @@ -668,10 +723,9 @@ async def test_successful_conversation_retrieval( assert isinstance(response, ConversationResponse) assert response.conversation_id == VALID_CONVERSATION_ID - assert response.chat_history == expected_chat_history - mock_client.agents.session.list.assert_called_once_with( - agent_id=VALID_CONVERSATION_ID - ) + # Convert ConversationTurn objects to dicts for comparison + actual_history = [turn.model_dump() for turn in response.chat_history] + assert actual_history == expected_chat_history @pytest.mark.asyncio async def test_retrieve_conversation_returns_none( @@ -682,11 +736,13 @@ async def test_retrieve_conversation_returns_none( ) -> None: """Test when retrieve_conversation returns None.""" mock_authorization_resolvers(mocker) - mocker.patch("app.endpoints.conversations.configuration", setup_configuration) - mocker.patch("app.endpoints.conversations.check_suid", return_value=True) - mocker.patch("app.endpoints.conversations.can_access_conversation") mocker.patch( - "app.endpoints.conversations.retrieve_conversation", return_value=None + "app.endpoints.conversations_v1.configuration", setup_configuration + ) + mocker.patch("app.endpoints.conversations_v1.check_suid", return_value=True) + mocker.patch("app.endpoints.conversations_v1.can_access_conversation") + mocker.patch( + "app.endpoints.conversations_v1.retrieve_conversation", return_value=None ) with pytest.raises(HTTPException) as exc_info: @@ -711,19 +767,26 @@ async def test_no_sessions_found_in_get_conversation( ) -> None: """Test when no sessions are found for the conversation.""" mock_authorization_resolvers(mocker) - mocker.patch("app.endpoints.conversations.configuration", setup_configuration) - mocker.patch("app.endpoints.conversations.check_suid", return_value=True) - mocker.patch("app.endpoints.conversations.can_access_conversation") mocker.patch( - "app.endpoints.conversations.retrieve_conversation", + "app.endpoints.conversations_v1.configuration", setup_configuration + ) + mocker.patch("app.endpoints.conversations_v1.check_suid", return_value=True) + mocker.patch("app.endpoints.conversations_v1.can_access_conversation") + mocker.patch( + "app.endpoints.conversations_v1.retrieve_conversation", return_value=mock_conversation, ) - # Mock AsyncLlamaStackClientHolder with empty sessions list + # Mock database session (empty db_turns since API call fails before query) + mock_database_session(mocker, db_turns=[]) + + # Mock AsyncLlamaStackClientHolder with empty items list mock_client = mocker.AsyncMock() - mock_client.agents.session.list.return_value = mocker.Mock(data=[]) + mock_items_response = mocker.Mock() + mock_items_response.data = [] # Empty items list + mock_client.conversations.items.list.return_value = mock_items_response mock_client_holder = mocker.patch( - "app.endpoints.conversations.AsyncLlamaStackClientHolder" + "app.endpoints.conversations_v1.AsyncLlamaStackClientHolder" ) mock_client_holder.return_value.get_client.return_value = mock_client @@ -749,27 +812,89 @@ async def test_sqlalchemy_error_in_get_conversation( ) -> None: """Test when SQLAlchemyError is raised during conversation retrieval.""" mock_authorization_resolvers(mocker) - mocker.patch("app.endpoints.conversations.configuration", 
setup_configuration) - mocker.patch("app.endpoints.conversations.check_suid", return_value=True) - mocker.patch("app.endpoints.conversations.can_access_conversation") mocker.patch( - "app.endpoints.conversations.retrieve_conversation", + "app.endpoints.conversations_v1.configuration", setup_configuration + ) + mocker.patch("app.endpoints.conversations_v1.check_suid", return_value=True) + mocker.patch("app.endpoints.conversations_v1.can_access_conversation") + mocker.patch( + "app.endpoints.conversations_v1.retrieve_conversation", return_value=mock_conversation, ) - # Mock AsyncLlamaStackClientHolder - SQLAlchemyError should come from session.retrieve + # Mock AsyncLlamaStackClientHolder - return items successfully mock_client = mocker.AsyncMock() - mock_session_list_response = mocker.Mock() - mock_session_list_response.data = [{"session_id": VALID_CONVERSATION_ID}] - mock_client.agents.session.list.return_value = mock_session_list_response - mock_client.agents.session.retrieve.side_effect = SQLAlchemyError( - "Database error" - ) + mock_items_response = mocker.Mock() + mock_items_response.data = [ + mocker.Mock(type="message", role="user", content="Hello"), + mocker.Mock(type="message", role="assistant", content="Hi!"), + ] + mock_client.conversations.items.list.return_value = mock_items_response mock_client_holder = mocker.patch( - "app.endpoints.conversations.AsyncLlamaStackClientHolder" + "app.endpoints.conversations_v1.AsyncLlamaStackClientHolder" ) mock_client_holder.return_value.get_client.return_value = mock_client + # Mock database session - SQLAlchemyError should come from UserTurn query + mock_session = mocker.Mock() + + # Make the UserTurn query raise SQLAlchemyError + def query_side_effect(model_class: type[Any]) -> Any: + if model_class == UserTurn: + mock_query = mocker.Mock() + mock_query.filter_by.return_value.order_by.return_value.all.side_effect = SQLAlchemyError( # pylint: disable=line-too-long + "Database error" + ) + return mock_query + # Other queries work normally + mock_query = mocker.Mock() + mock_query.all.return_value = [] + mock_query.filter_by.return_value.all.return_value = [] + mock_query.filter_by.return_value.first.return_value = None + return mock_query + + mock_session.query.side_effect = query_side_effect + mock_session_context = mocker.MagicMock() + mock_session_context.__enter__.return_value = mock_session + mock_session_context.__exit__.return_value = None + mocker.patch( + "app.endpoints.conversations_v1.get_session", + return_value=mock_session_context, + ) + + with pytest.raises(HTTPException) as exc_info: + await get_conversation_endpoint_handler( + request=dummy_request, + conversation_id=VALID_CONVERSATION_ID, + auth=MOCK_AUTH, + ) + + assert exc_info.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR + detail = exc_info.value.detail + assert isinstance(detail, dict) + assert "Database" in detail["response"] # type: ignore + + @pytest.mark.asyncio + async def test_sqlalchemy_error_in_retrieve_conversation( + self, + mocker: MockerFixture, + setup_configuration: AppConfig, + dummy_request: Request, + ) -> None: + """Test when SQLAlchemyError is raised during retrieve_conversation call.""" + mock_authorization_resolvers(mocker) + mocker.patch( + "app.endpoints.conversations_v1.configuration", setup_configuration + ) + mocker.patch("app.endpoints.conversations_v1.check_suid", return_value=True) + mocker.patch("app.endpoints.conversations_v1.can_access_conversation") + + # Make retrieve_conversation raise SQLAlchemyError + mocker.patch( + 
"app.endpoints.conversations_v1.retrieve_conversation", + side_effect=SQLAlchemyError("Database error"), + ) + with pytest.raises(HTTPException) as exc_info: await get_conversation_endpoint_handler( request=dummy_request, @@ -793,7 +918,7 @@ async def test_configuration_not_loaded( """Test the endpoint when configuration is not loaded.""" mock_authorization_resolvers(mocker) mock_config = AppConfig() - mocker.patch("app.endpoints.conversations.configuration", mock_config) + mocker.patch("app.endpoints.conversations_v1.configuration", mock_config) with pytest.raises(HTTPException) as exc_info: await delete_conversation_endpoint_handler( @@ -817,8 +942,10 @@ async def test_invalid_conversation_id_format( ) -> None: """Test the endpoint with an invalid conversation ID format.""" mock_authorization_resolvers(mocker) - mocker.patch("app.endpoints.conversations.configuration", setup_configuration) - mocker.patch("app.endpoints.conversations.check_suid", return_value=False) + mocker.patch( + "app.endpoints.conversations_v1.configuration", setup_configuration + ) + mocker.patch("app.endpoints.conversations_v1.check_suid", return_value=False) with pytest.raises(HTTPException) as exc_info: await delete_conversation_endpoint_handler( @@ -843,18 +970,25 @@ async def test_llama_stack_connection_error( ) -> None: """Test the endpoint when LlamaStack connection fails.""" mock_authorization_resolvers(mocker) - mocker.patch("app.endpoints.conversations.configuration", setup_configuration) - mocker.patch("app.endpoints.conversations.check_suid", return_value=True) - mocker.patch("app.endpoints.conversations.can_access_conversation") - mocker.patch("app.endpoints.conversations.retrieve_conversation") + mocker.patch( + "app.endpoints.conversations_v1.configuration", setup_configuration + ) + mocker.patch("app.endpoints.conversations_v1.check_suid", return_value=True) + mocker.patch("app.endpoints.conversations_v1.can_access_conversation") + mocker.patch("app.endpoints.conversations_v1.retrieve_conversation") + + # Mock delete_conversation to succeed locally + mocker.patch( + "app.endpoints.conversations_v1.delete_conversation", return_value=True + ) # Mock AsyncLlamaStackClientHolder to raise APIConnectionError mock_client = mocker.AsyncMock() - mock_client.agents.session.delete.side_effect = APIConnectionError( + mock_client.conversations.delete.side_effect = APIConnectionError( request=None # type: ignore ) mock_client_holder = mocker.patch( - "app.endpoints.conversations.AsyncLlamaStackClientHolder" + "app.endpoints.conversations_v1.AsyncLlamaStackClientHolder" ) mock_client_holder.return_value.get_client.return_value = mock_client @@ -868,6 +1002,7 @@ async def test_llama_stack_connection_error( assert exc_info.value.status_code == status.HTTP_503_SERVICE_UNAVAILABLE detail = exc_info.value.detail assert isinstance(detail, dict) + # ServiceUnavailableResponse model_dump() creates detail with response and cause assert detail["response"] == "Unable to connect to Llama Stack" # type: ignore @pytest.mark.asyncio @@ -879,69 +1014,41 @@ async def test_llama_stack_not_found_error( ) -> None: """Test the endpoint when LlamaStack returns NotFoundError.""" mock_authorization_resolvers(mocker) - mocker.patch("app.endpoints.conversations.configuration", setup_configuration) - mocker.patch("app.endpoints.conversations.check_suid", return_value=True) - mocker.patch("app.endpoints.conversations.can_access_conversation") - mocker.patch("app.endpoints.conversations.retrieve_conversation") - - # Mock 
AsyncLlamaStackClientHolder to raise NotFoundError - mock_client = mocker.AsyncMock() - mock_client.agents.session.delete.side_effect = NotFoundError( - message="Session not found", response=mocker.Mock(request=None), body=None - ) - mock_client_holder = mocker.patch( - "app.endpoints.conversations.AsyncLlamaStackClientHolder" + mocker.patch( + "app.endpoints.conversations_v1.configuration", setup_configuration ) - mock_client_holder.return_value.get_client.return_value = mock_client - - with pytest.raises(HTTPException) as exc_info: - await delete_conversation_endpoint_handler( - request=dummy_request, - conversation_id=VALID_CONVERSATION_ID, - auth=MOCK_AUTH, - ) - - assert exc_info.value.status_code == status.HTTP_404_NOT_FOUND - detail = exc_info.value.detail - assert isinstance(detail, dict) - assert "Conversation not found" in detail["response"] # type: ignore - assert "does not exist" in detail["cause"] # type: ignore - assert VALID_CONVERSATION_ID in detail["cause"] # type: ignore + mocker.patch("app.endpoints.conversations_v1.check_suid", return_value=True) + mocker.patch("app.endpoints.conversations_v1.can_access_conversation") + mocker.patch("app.endpoints.conversations_v1.retrieve_conversation") - @pytest.mark.asyncio - async def test_session_deletion_exception( - self, - mocker: MockerFixture, - setup_configuration: AppConfig, - dummy_request: Request, - ) -> None: - """Test the endpoint when session deletion raises an exception.""" - mock_authorization_resolvers(mocker) - mocker.patch("app.endpoints.conversations.configuration", setup_configuration) - mocker.patch("app.endpoints.conversations.check_suid", return_value=True) - mocker.patch("app.endpoints.conversations.can_access_conversation") - mocker.patch("app.endpoints.conversations.retrieve_conversation") + # Mock delete_conversation to succeed locally + mocker.patch( + "app.endpoints.conversations_v1.delete_conversation", return_value=True + ) - # Mock AsyncLlamaStackClientHolder to raise a general exception + # Mock AsyncLlamaStackClientHolder - a 404 APIStatusError is caught and treated as success mock_client = mocker.AsyncMock() - mock_client.agents.session.delete.side_effect = APIConnectionError( - request=None # type: ignore + mock_client.conversations.delete.side_effect = APIStatusError( + message="Conversation not found", + response=mocker.Mock(status_code=404, request=None), + body=None, ) mock_client_holder = mocker.patch( - "app.endpoints.conversations.AsyncLlamaStackClientHolder" + "app.endpoints.conversations_v1.AsyncLlamaStackClientHolder" ) mock_client_holder.return_value.get_client.return_value = mock_client - with pytest.raises(HTTPException) as exc_info: - await delete_conversation_endpoint_handler( - request=dummy_request, - conversation_id=VALID_CONVERSATION_ID, - auth=MOCK_AUTH, - ) - assert exc_info.value.status_code == status.HTTP_503_SERVICE_UNAVAILABLE - detail = exc_info.value.detail - assert isinstance(detail, dict) - assert "Unable to connect to Llama Stack" in detail["response"] # type: ignore + + # The 404 is caught and treated as already deleted, so the call succeeds + response = await delete_conversation_endpoint_handler( + request=dummy_request, + conversation_id=VALID_CONVERSATION_ID, + auth=MOCK_AUTH, + ) + + assert isinstance(response, ConversationDeleteResponse) + assert response.conversation_id == VALID_CONVERSATION_ID + assert response.success is True + assert "deleted successfully" in response.response @pytest.mark.asyncio @@ -952,10 +1059,12 @@ 
async def test_delete_conversation_forbidden( mock_conversation: MockType, ) -> None: """Test forbidden deletion when user lacks permission to delete conversation.""" - mocker.patch("app.endpoints.conversations.configuration", setup_configuration) - mocker.patch("app.endpoints.conversations.check_suid", return_value=True) mocker.patch( - "app.endpoints.conversations.retrieve_conversation", + "app.endpoints.conversations_v1.configuration", setup_configuration + ) + mocker.patch("app.endpoints.conversations_v1.check_suid", return_value=True) + mocker.patch( + "app.endpoints.conversations_v1.retrieve_conversation", return_value=mock_conversation, ) mocker.patch( @@ -1010,34 +1119,39 @@ async def test_delete_others_conversations_allowed_for_authorized_user( }, ) # Allow user to delete other users' conversations - mocker.patch("app.endpoints.conversations.configuration", setup_configuration) - mocker.patch("app.endpoints.conversations.check_suid", return_value=True) mocker.patch( - "app.endpoints.conversations.retrieve_conversation", + "app.endpoints.conversations_v1.configuration", setup_configuration + ) + mocker.patch("app.endpoints.conversations_v1.check_suid", return_value=True) + mocker.patch( + "app.endpoints.conversations_v1.retrieve_conversation", return_value=mock_conversation, ) - mock_client = mocker.AsyncMock() - mock_client.agents.session.list.return_value.data = [ - {"session_id": VALID_CONVERSATION_ID} - ] - mock_client.agents.session.delete.return_value = None + # Mock delete_conversation to succeed locally mocker.patch( - "app.endpoints.conversations.AsyncLlamaStackClientHolder.get_client", - return_value=mock_client, + "app.endpoints.conversations_v1.delete_conversation", return_value=True ) - mocker.patch( - "app.endpoints.conversations.delete_conversation", return_value=None + # Mock AsyncLlamaStackClientHolder - conversations_v1 uses conversations.delete + mock_client = mocker.AsyncMock() + mock_delete_response = mocker.Mock() + mock_delete_response.deleted = True + mock_client.conversations.delete.return_value = mock_delete_response + mock_client_holder = mocker.patch( + "app.endpoints.conversations_v1.AsyncLlamaStackClientHolder" ) + mock_client_holder.return_value.get_client.return_value = mock_client + response = await delete_conversation_endpoint_handler( request=dummy_request, conversation_id=VALID_CONVERSATION_ID, auth=MOCK_AUTH, ) - assert response.success is True + assert isinstance(response, ConversationDeleteResponse) assert response.conversation_id == VALID_CONVERSATION_ID + assert response.success is True assert "deleted successfully" in response.response @pytest.mark.asyncio @@ -1049,23 +1163,25 @@ async def test_successful_conversation_deletion( ) -> None: """Test successful conversation deletion.""" mock_authorization_resolvers(mocker) - mocker.patch("app.endpoints.conversations.configuration", setup_configuration) - mocker.patch("app.endpoints.conversations.check_suid", return_value=True) - mocker.patch("app.endpoints.conversations.can_access_conversation") - mocker.patch("app.endpoints.conversations.retrieve_conversation") + mocker.patch( + "app.endpoints.conversations_v1.configuration", setup_configuration + ) + mocker.patch("app.endpoints.conversations_v1.check_suid", return_value=True) + mocker.patch("app.endpoints.conversations_v1.can_access_conversation") + mocker.patch("app.endpoints.conversations_v1.retrieve_conversation") - # Mock the 
delete_conversation function to return True (successful deletion) + mock_delete = mocker.patch( + "app.endpoints.conversations_v1.delete_conversation", return_value=True + ) - # Mock AsyncLlamaStackClientHolder + # Mock AsyncLlamaStackClientHolder - conversations_v1 uses conversations.delete mock_client = mocker.AsyncMock() - # Ensure the endpoint sees an existing session so it proceeds to delete - mock_client.agents.session.list.return_value = mocker.Mock( - data=[{"session_id": VALID_CONVERSATION_ID}] - ) - mock_client.agents.session.delete.return_value = None # Successful deletion + mock_delete_response = mocker.Mock() + mock_delete_response.deleted = True + mock_client.conversations.delete.return_value = mock_delete_response mock_client_holder = mocker.patch( - "app.endpoints.conversations.AsyncLlamaStackClientHolder" + "app.endpoints.conversations_v1.AsyncLlamaStackClientHolder" ) mock_client_holder.return_value.get_client.return_value = mock_client @@ -1076,10 +1192,9 @@ async def test_successful_conversation_deletion( assert isinstance(response, ConversationDeleteResponse) assert response.conversation_id == VALID_CONVERSATION_ID assert response.success is True - assert response.response == "Conversation deleted successfully" - mock_client.agents.session.delete.assert_called_once_with( - agent_id=VALID_CONVERSATION_ID, session_id=VALID_CONVERSATION_ID - ) + assert "deleted successfully" in response.response + mock_delete.assert_called_once() + mock_client.conversations.delete.assert_called_once() @pytest.mark.asyncio async def test_retrieve_conversation_returns_none_in_delete( @@ -1088,53 +1203,32 @@ async def test_retrieve_conversation_returns_none_in_delete( setup_configuration: AppConfig, dummy_request: Request, ) -> None: - """Test when retrieve_conversation returns None in delete endpoint.""" + """Test when conversation doesn't exist in delete endpoint.""" mock_authorization_resolvers(mocker) - mocker.patch("app.endpoints.conversations.configuration", setup_configuration) - mocker.patch("app.endpoints.conversations.check_suid", return_value=True) - mocker.patch("app.endpoints.conversations.can_access_conversation") mocker.patch( - "app.endpoints.conversations.retrieve_conversation", return_value=None + "app.endpoints.conversations_v1.configuration", setup_configuration ) - - with pytest.raises(HTTPException) as exc_info: - await delete_conversation_endpoint_handler( - request=dummy_request, - conversation_id=VALID_CONVERSATION_ID, - auth=MOCK_AUTH, - ) - - assert exc_info.value.status_code == status.HTTP_404_NOT_FOUND - detail = exc_info.value.detail - assert isinstance(detail, dict) - assert "Conversation not found" in detail["response"] # type: ignore - - @pytest.mark.asyncio - async def test_no_sessions_found_in_delete( - self, - mocker: MockerFixture, - setup_configuration: AppConfig, - dummy_request: Request, - mock_conversation: MockType, - ) -> None: - """Test when no sessions are found in delete endpoint (early return).""" - mock_authorization_resolvers(mocker) - mocker.patch("app.endpoints.conversations.configuration", setup_configuration) - mocker.patch("app.endpoints.conversations.check_suid", return_value=True) - mocker.patch("app.endpoints.conversations.can_access_conversation") + mocker.patch("app.endpoints.conversations_v1.check_suid", return_value=True) + # can_access_conversation returns True when conversation doesn't exist mocker.patch( - "app.endpoints.conversations.retrieve_conversation", - return_value=mock_conversation, + 
"app.endpoints.conversations_v1.can_access_conversation", return_value=True + ) + # delete_conversation returns False when conversation doesn't exist + mocker.patch( + "app.endpoints.conversations_v1.delete_conversation", return_value=False ) - # Mock AsyncLlamaStackClientHolder with empty sessions list + # Mock AsyncLlamaStackClientHolder - conversations_v1 uses conversations.delete mock_client = mocker.AsyncMock() - mock_client.agents.session.list.return_value = mocker.Mock(data=[]) + mock_delete_response = mocker.Mock() + mock_delete_response.deleted = True + mock_client.conversations.delete.return_value = mock_delete_response mock_client_holder = mocker.patch( - "app.endpoints.conversations.AsyncLlamaStackClientHolder" + "app.endpoints.conversations_v1.AsyncLlamaStackClientHolder" ) mock_client_holder.return_value.get_client.return_value = mock_client + # Delete endpoint succeeds but returns deleted=False response = await delete_conversation_endpoint_handler( request=dummy_request, conversation_id=VALID_CONVERSATION_ID, @@ -1143,8 +1237,8 @@ async def test_no_sessions_found_in_delete( assert isinstance(response, ConversationDeleteResponse) assert response.conversation_id == VALID_CONVERSATION_ID - assert response.success is True # Operation completed successfully - assert "cannot be deleted" in response.response # But nothing was deleted + assert response.success is True + assert "cannot be deleted" in response.response # Not found locally @pytest.mark.asyncio async def test_sqlalchemy_error_in_delete( @@ -1156,11 +1250,13 @@ async def test_sqlalchemy_error_in_delete( ) -> None: """Test when SQLAlchemyError is raised during conversation deletion.""" mock_authorization_resolvers(mocker) - mocker.patch("app.endpoints.conversations.configuration", setup_configuration) - mocker.patch("app.endpoints.conversations.check_suid", return_value=True) - mocker.patch("app.endpoints.conversations.can_access_conversation") mocker.patch( - "app.endpoints.conversations.retrieve_conversation", + "app.endpoints.conversations_v1.configuration", setup_configuration + ) + mocker.patch("app.endpoints.conversations_v1.check_suid", return_value=True) + mocker.patch("app.endpoints.conversations_v1.can_access_conversation") + mocker.patch( + "app.endpoints.conversations_v1.retrieve_conversation", return_value=mock_conversation, ) @@ -1171,13 +1267,13 @@ async def test_sqlalchemy_error_in_delete( mock_client.agents.session.list.return_value = mock_session_list_response mock_client.agents.session.delete.return_value = None mock_client_holder = mocker.patch( - "app.endpoints.conversations.AsyncLlamaStackClientHolder" + "app.endpoints.conversations_v1.AsyncLlamaStackClientHolder" ) mock_client_holder.return_value.get_client.return_value = mock_client # Mock delete_conversation to raise SQLAlchemyError mocker.patch( - "app.endpoints.conversations.delete_conversation", + "app.endpoints.conversations_v1.delete_conversation", side_effect=SQLAlchemyError("Database error"), ) @@ -1205,7 +1301,7 @@ async def test_configuration_not_loaded( """Test the endpoint when configuration is not loaded.""" mock_authorization_resolvers(mocker) mock_config = AppConfig() - mocker.patch("app.endpoints.conversations.configuration", mock_config) + mocker.patch("app.endpoints.conversations_v1.configuration", mock_config) with pytest.raises(HTTPException) as exc_info: await get_conversations_list_endpoint_handler( @@ -1226,7 +1322,9 @@ async def test_successful_conversations_list_retrieval( ) -> None: """Test successful retrieval of 
conversations list.""" mock_authorization_resolvers(mocker) - mocker.patch("app.endpoints.conversations.configuration", setup_configuration) + mocker.patch( + "app.endpoints.conversations_v1.configuration", setup_configuration + ) # Mock database session and query results mock_conversations = [ @@ -1289,7 +1387,9 @@ async def test_empty_conversations_list( ) -> None: """Test when user has no conversations.""" mock_authorization_resolvers(mocker) - mocker.patch("app.endpoints.conversations.configuration", setup_configuration) + mocker.patch( + "app.endpoints.conversations_v1.configuration", setup_configuration + ) # Mock database session with no results mock_database_session(mocker, []) @@ -1311,7 +1411,9 @@ async def test_database_exception( ) -> None: """Test when database query raises an exception.""" mock_authorization_resolvers(mocker) - mocker.patch("app.endpoints.conversations.configuration", setup_configuration) + mocker.patch( + "app.endpoints.conversations_v1.configuration", setup_configuration + ) # Mock database session to raise exception mock_session = mock_database_session(mocker) @@ -1331,7 +1433,9 @@ async def test_sqlalchemy_error_in_list( ) -> None: """Test when database query raises SQLAlchemyError.""" mock_authorization_resolvers(mocker) - mocker.patch("app.endpoints.conversations.configuration", setup_configuration) + mocker.patch( + "app.endpoints.conversations_v1.configuration", setup_configuration + ) # Mock database session to raise SQLAlchemyError when all() is called # Since dummy_request has all actions, it will use query directly (not filter_by) @@ -1348,7 +1452,8 @@ async def test_sqlalchemy_error_in_list( mock_session_context.__enter__.return_value = mock_session mock_session_context.__exit__.return_value = None mocker.patch( - "app.endpoints.conversations.get_session", return_value=mock_session_context + "app.endpoints.conversations_v1.get_session", + return_value=mock_session_context, ) with pytest.raises(HTTPException) as exc_info: @@ -1370,7 +1475,9 @@ async def test_conversations_list_with_none_topic_summary( ) -> None: """Test conversations list when topic_summary is None.""" mock_authorization_resolvers(mocker) - mocker.patch("app.endpoints.conversations.configuration", setup_configuration) + mocker.patch( + "app.endpoints.conversations_v1.configuration", setup_configuration + ) # Mock database session with conversation having None topic_summary mock_conversations = [ @@ -1407,7 +1514,9 @@ async def test_conversations_list_with_mixed_topic_summaries( ) -> None: """Test conversations list with mixed topic_summary values (some None, some not).""" mock_authorization_resolvers(mocker) - mocker.patch("app.endpoints.conversations.configuration", setup_configuration) + mocker.patch( + "app.endpoints.conversations_v1.configuration", setup_configuration + ) # Mock database session with mixed topic_summary values mock_conversations = [ @@ -1475,7 +1584,9 @@ async def test_conversations_list_with_empty_topic_summary( ) -> None: """Test conversations list when topic_summary is an empty string.""" mock_authorization_resolvers(mocker) - mocker.patch("app.endpoints.conversations.configuration", setup_configuration) + mocker.patch( + "app.endpoints.conversations_v1.configuration", setup_configuration + ) # Mock database session with conversation having empty topic_summary mock_conversations = [ @@ -1512,7 +1623,9 @@ async def test_conversations_list_topic_summary_field_presence( ) -> None: """Test that topic_summary field is always present in ConversationDetails 
objects.""" mock_authorization_resolvers(mocker) - mocker.patch("app.endpoints.conversations.configuration", setup_configuration) + mocker.patch( + "app.endpoints.conversations_v1.configuration", setup_configuration + ) # Mock database session with conversations mock_conversations = [ @@ -1546,3 +1659,380 @@ async def test_conversations_list_topic_summary_field_presence( conv_dict = conv.model_dump() assert "topic_summary" in conv_dict assert conv_dict["topic_summary"] == "Test topic summary" + + +class TestUpdateConversationEndpoint: + """Test cases for the PUT /conversations/{conversation_id} endpoint.""" + + @pytest.mark.asyncio + async def test_configuration_not_loaded( + self, mocker: MockerFixture, dummy_request: Request + ) -> None: + """Test the endpoint when configuration is not loaded.""" + mock_authorization_resolvers(mocker) + mock_config = AppConfig() + mocker.patch("app.endpoints.conversations_v1.configuration", mock_config) + + update_request = ConversationUpdateRequest(topic_summary="New topic") + + with pytest.raises(HTTPException) as exc_info: + await update_conversation_endpoint_handler( + request=dummy_request, + conversation_id=VALID_CONVERSATION_ID, + update_request=update_request, + auth=MOCK_AUTH, + ) + + assert exc_info.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR + detail = exc_info.value.detail + assert isinstance(detail, dict) + assert "Configuration is not loaded" in detail["response"] # type: ignore + + @pytest.mark.asyncio + async def test_invalid_conversation_id_format( + self, + mocker: MockerFixture, + setup_configuration: AppConfig, + dummy_request: Request, + ) -> None: + """Test the endpoint with an invalid conversation ID format.""" + mock_authorization_resolvers(mocker) + mocker.patch( + "app.endpoints.conversations_v1.configuration", setup_configuration + ) + mocker.patch("app.endpoints.conversations_v1.check_suid", return_value=False) + + update_request = ConversationUpdateRequest(topic_summary="New topic") + + with pytest.raises(HTTPException) as exc_info: + await update_conversation_endpoint_handler( + request=dummy_request, + conversation_id=INVALID_CONVERSATION_ID, + update_request=update_request, + auth=MOCK_AUTH, + ) + + assert exc_info.value.status_code == status.HTTP_400_BAD_REQUEST + detail = exc_info.value.detail + assert isinstance(detail, dict) + assert "Invalid conversation ID format" in detail["response"] # type: ignore + + @pytest.mark.asyncio + async def test_update_conversation_forbidden( + self, + mocker: MockerFixture, + setup_configuration: AppConfig, + dummy_request: Request, + mock_conversation: MockType, + ) -> None: + """Test forbidden access when user lacks permission to update conversation.""" + mock_authorization_resolvers(mocker) + mocker.patch( + "app.endpoints.conversations_v1.configuration", setup_configuration + ) + mocker.patch("app.endpoints.conversations_v1.check_suid", return_value=True) + mocker.patch( + "app.endpoints.conversations_v1.retrieve_conversation", + return_value=mock_conversation, + ) + mocker.patch( + "authorization.resolvers.NoopAccessResolver.get_actions", + return_value=set(Action.UPDATE_CONVERSATION), + ) # User can only update their own conversations + + # Mock can_access_conversation to return False (user doesn't have access) + mocker.patch( + "app.endpoints.conversations_v1.can_access_conversation", return_value=False + ) + + update_request = ConversationUpdateRequest(topic_summary="New topic") + + with pytest.raises(HTTPException) as exc_info: + await 
update_conversation_endpoint_handler( + request=dummy_request, + conversation_id=VALID_CONVERSATION_ID, + update_request=update_request, + auth=MOCK_AUTH, + ) + + assert exc_info.value.status_code == status.HTTP_403_FORBIDDEN + detail = exc_info.value.detail + assert isinstance(detail, dict) + assert "does not have permission" in detail["cause"] # type: ignore + + @pytest.mark.asyncio + async def test_conversation_not_found_in_update( + self, + mocker: MockerFixture, + setup_configuration: AppConfig, + dummy_request: Request, + ) -> None: + """Test when conversation is not found in update endpoint.""" + mock_authorization_resolvers(mocker) + mocker.patch( + "app.endpoints.conversations_v1.configuration", setup_configuration + ) + mocker.patch("app.endpoints.conversations_v1.check_suid", return_value=True) + mocker.patch("app.endpoints.conversations_v1.can_access_conversation") + mocker.patch( + "app.endpoints.conversations_v1.retrieve_conversation", return_value=None + ) + + update_request = ConversationUpdateRequest(topic_summary="New topic") + + with pytest.raises(HTTPException) as exc_info: + await update_conversation_endpoint_handler( + request=dummy_request, + conversation_id=VALID_CONVERSATION_ID, + update_request=update_request, + auth=MOCK_AUTH, + ) + + assert exc_info.value.status_code == status.HTTP_404_NOT_FOUND + detail = exc_info.value.detail + assert isinstance(detail, dict) + assert "Conversation not found" in detail["response"] # type: ignore + + @pytest.mark.asyncio + async def test_sqlalchemy_error_in_retrieve_conversation_update( + self, + mocker: MockerFixture, + setup_configuration: AppConfig, + dummy_request: Request, + ) -> None: + """Test when SQLAlchemyError is raised during retrieve_conversation in update.""" + mock_authorization_resolvers(mocker) + mocker.patch( + "app.endpoints.conversations_v1.configuration", setup_configuration + ) + mocker.patch("app.endpoints.conversations_v1.check_suid", return_value=True) + mocker.patch("app.endpoints.conversations_v1.can_access_conversation") + mocker.patch( + "app.endpoints.conversations_v1.retrieve_conversation", + side_effect=SQLAlchemyError("Database error"), + ) + + update_request = ConversationUpdateRequest(topic_summary="New topic") + + with pytest.raises(HTTPException) as exc_info: + await update_conversation_endpoint_handler( + request=dummy_request, + conversation_id=VALID_CONVERSATION_ID, + update_request=update_request, + auth=MOCK_AUTH, + ) + + assert exc_info.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR + detail = exc_info.value.detail + assert isinstance(detail, dict) + assert "Database" in detail["response"] # type: ignore + + @pytest.mark.asyncio + async def test_successful_conversation_update( + self, + mocker: MockerFixture, + setup_configuration: AppConfig, + dummy_request: Request, + mock_conversation: MockType, + ) -> None: + """Test successful conversation update.""" + mock_authorization_resolvers(mocker) + mocker.patch( + "app.endpoints.conversations_v1.configuration", setup_configuration + ) + mocker.patch("app.endpoints.conversations_v1.check_suid", return_value=True) + mocker.patch("app.endpoints.conversations_v1.can_access_conversation") + mocker.patch( + "app.endpoints.conversations_v1.retrieve_conversation", + return_value=mock_conversation, + ) + + # Mock database session for update + mock_session = mocker.Mock() + mock_db_conv = mocker.Mock() + mock_db_conv.topic_summary = None + mock_session.query.return_value.filter_by.return_value.first.return_value = ( + mock_db_conv + ) + 
mock_session_context = mocker.MagicMock() + mock_session_context.__enter__.return_value = mock_session + mock_session_context.__exit__.return_value = None + mocker.patch( + "app.endpoints.conversations_v1.get_session", + return_value=mock_session_context, + ) + + # Mock AsyncLlamaStackClientHolder + mock_client = mocker.AsyncMock() + mock_client.conversations.update.return_value = None + mock_client_holder = mocker.patch( + "app.endpoints.conversations_v1.AsyncLlamaStackClientHolder" + ) + mock_client_holder.return_value.get_client.return_value = mock_client + + update_request = ConversationUpdateRequest(topic_summary="New topic summary") + + response = await update_conversation_endpoint_handler( + request=dummy_request, + conversation_id=VALID_CONVERSATION_ID, + update_request=update_request, + auth=MOCK_AUTH, + ) + + assert isinstance(response, ConversationUpdateResponse) + assert response.conversation_id == VALID_CONVERSATION_ID + assert response.success is True + assert "updated successfully" in response.message + mock_client.conversations.update.assert_called_once() + mock_session.commit.assert_called_once() + + @pytest.mark.asyncio + async def test_llama_stack_connection_error_in_update( + self, + mocker: MockerFixture, + setup_configuration: AppConfig, + dummy_request: Request, + mock_conversation: MockType, + ) -> None: + """Test the endpoint when LlamaStack connection fails during update.""" + mock_authorization_resolvers(mocker) + mocker.patch( + "app.endpoints.conversations_v1.configuration", setup_configuration + ) + mocker.patch("app.endpoints.conversations_v1.check_suid", return_value=True) + mocker.patch("app.endpoints.conversations_v1.can_access_conversation") + mocker.patch( + "app.endpoints.conversations_v1.retrieve_conversation", + return_value=mock_conversation, + ) + + # Mock AsyncLlamaStackClientHolder to raise APIConnectionError + mock_client = mocker.AsyncMock() + mock_client.conversations.update.side_effect = APIConnectionError( + request=None # type: ignore + ) + mock_client_holder = mocker.patch( + "app.endpoints.conversations_v1.AsyncLlamaStackClientHolder" + ) + mock_client_holder.return_value.get_client.return_value = mock_client + + update_request = ConversationUpdateRequest(topic_summary="New topic") + + with pytest.raises(HTTPException) as exc_info: + await update_conversation_endpoint_handler( + request=dummy_request, + conversation_id=VALID_CONVERSATION_ID, + update_request=update_request, + auth=MOCK_AUTH, + ) + + assert exc_info.value.status_code == status.HTTP_503_SERVICE_UNAVAILABLE + detail = exc_info.value.detail + assert isinstance(detail, dict) + assert detail["response"] == "Unable to connect to Llama Stack" # type: ignore + + @pytest.mark.asyncio + async def test_llama_stack_not_found_error_in_update( + self, + mocker: MockerFixture, + setup_configuration: AppConfig, + dummy_request: Request, + mock_conversation: MockType, + ) -> None: + """Test the endpoint when LlamaStack returns a 404 APIStatusError during update.""" + mock_authorization_resolvers(mocker) + mocker.patch( + "app.endpoints.conversations_v1.configuration", setup_configuration + ) + mocker.patch("app.endpoints.conversations_v1.check_suid", return_value=True) + mocker.patch("app.endpoints.conversations_v1.can_access_conversation") + mocker.patch( + "app.endpoints.conversations_v1.retrieve_conversation", + return_value=mock_conversation, + ) + + # Mock AsyncLlamaStackClientHolder to raise APIStatusError + mock_client = mocker.AsyncMock() + mock_client.conversations.update.side_effect = 
APIStatusError( + message="Conversation not found", + response=mocker.Mock(status_code=404, request=None), + body=None, + ) + mock_client_holder = mocker.patch( + "app.endpoints.conversations_v1.AsyncLlamaStackClientHolder" + ) + mock_client_holder.return_value.get_client.return_value = mock_client + + update_request = ConversationUpdateRequest(topic_summary="New topic") + + with pytest.raises(HTTPException) as exc_info: + await update_conversation_endpoint_handler( + request=dummy_request, + conversation_id=VALID_CONVERSATION_ID, + update_request=update_request, + auth=MOCK_AUTH, + ) + + assert exc_info.value.status_code == status.HTTP_404_NOT_FOUND + detail = exc_info.value.detail + assert isinstance(detail, dict) + assert "Conversation not found" in detail["response"] # type: ignore + + @pytest.mark.asyncio + async def test_sqlalchemy_error_in_database_update( + self, + mocker: MockerFixture, + setup_configuration: AppConfig, + dummy_request: Request, + mock_conversation: MockType, + ) -> None: + """Test when SQLAlchemyError is raised during database update.""" + mock_authorization_resolvers(mocker) + mocker.patch( + "app.endpoints.conversations_v1.configuration", setup_configuration + ) + mocker.patch("app.endpoints.conversations_v1.check_suid", return_value=True) + mocker.patch("app.endpoints.conversations_v1.can_access_conversation") + mocker.patch( + "app.endpoints.conversations_v1.retrieve_conversation", + return_value=mock_conversation, + ) + + # Mock AsyncLlamaStackClientHolder - update succeeds + mock_client = mocker.AsyncMock() + mock_client.conversations.update.return_value = None + mock_client_holder = mocker.patch( + "app.endpoints.conversations_v1.AsyncLlamaStackClientHolder" + ) + mock_client_holder.return_value.get_client.return_value = mock_client + + # Mock database session - commit raises SQLAlchemyError + mock_session = mocker.Mock() + mock_db_conv = mocker.Mock() + mock_db_conv.topic_summary = None + mock_session.query.return_value.filter_by.return_value.first.return_value = ( + mock_db_conv + ) + mock_session.commit.side_effect = SQLAlchemyError("Database error") + mock_session_context = mocker.MagicMock() + mock_session_context.__enter__.return_value = mock_session + mock_session_context.__exit__.return_value = None + mocker.patch( + "app.endpoints.conversations_v1.get_session", + return_value=mock_session_context, + ) + + update_request = ConversationUpdateRequest(topic_summary="New topic") + + with pytest.raises(HTTPException) as exc_info: + await update_conversation_endpoint_handler( + request=dummy_request, + conversation_id=VALID_CONVERSATION_ID, + update_request=update_request, + auth=MOCK_AUTH, + ) + + assert exc_info.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR + detail = exc_info.value.detail + assert isinstance(detail, dict) + assert "Database" in detail["response"] # type: ignore diff --git a/tests/unit/app/endpoints/test_conversations_v2.py b/tests/unit/app/endpoints/test_conversations_v2.py index 1ee7f8d86..57019dad8 100644 --- a/tests/unit/app/endpoints/test_conversations_v2.py +++ b/tests/unit/app/endpoints/test_conversations_v2.py @@ -3,18 +3,19 @@ """Unit tests for the /conversations REST API endpoints.""" from datetime import datetime, timezone +from typing import Any, cast import pytest from fastapi import HTTPException, status from pytest_mock import MockerFixture, MockType from app.endpoints.conversations_v2 import ( + build_conversation_turn_from_cache_entry, check_conversation_existence, check_valid_conversation_id, 
delete_conversation_endpoint_handler, get_conversation_endpoint_handler, get_conversations_list_endpoint_handler, - transform_chat_message, update_conversation_endpoint_handler, ) from configuration import AppConfig @@ -23,57 +24,20 @@ from models.responses import ( ConversationData, ConversationUpdateResponse, - ReferencedDocument, ) from tests.unit.utils.auth_helpers import mock_authorization_resolvers +from utils.types import ToolCallSummary, ToolResultSummary MOCK_AUTH = ("mock_user_id", "mock_username", False, "mock_token") VALID_CONVERSATION_ID = "123e4567-e89b-12d3-a456-426614174000" INVALID_CONVERSATION_ID = "invalid-id" -def test_transform_message() -> None: - """Test the transform_chat_message transformation function.""" - entry = CacheEntry( - query="query", - response="response", - provider="provider", - model="model", - started_at="2024-01-01T00:00:00Z", - completed_at="2024-01-01T00:00:05Z", - ) - transformed = transform_chat_message(entry) - assert transformed is not None +class TestBuildConversationTurnFromCacheEntry: + """Test cases for the build_conversation_turn_from_cache_entry utility function.""" - assert "provider" in transformed - assert transformed["provider"] == "provider" - - assert "model" in transformed - assert transformed["model"] == "model" - - assert "started_at" in transformed - assert transformed["started_at"] == "2024-01-01T00:00:00Z" - - assert "completed_at" in transformed - assert transformed["completed_at"] == "2024-01-01T00:00:05Z" - - assert "messages" in transformed - assert len(transformed["messages"]) == 2 - - message1 = transformed["messages"][0] - assert message1["type"] == "user" - assert message1["content"] == "query" - - message2 = transformed["messages"][1] - assert message2["type"] == "assistant" - assert message2["content"] == "response" - - -class TestTransformChatMessage: - """Test cases for the transform_chat_message utility function.""" - - def test_transform_message_without_documents(self) -> None: - """Test the transformation when no referenced_documents are present.""" + def test_build_turn_without_tool_calls(self) -> None: + """Test building a turn when no tool calls/results are present.""" entry = CacheEntry( query="query", response="response", @@ -81,41 +45,36 @@ def test_transform_message_without_documents(self) -> None: model="model", started_at="2024-01-01T00:00:00Z", completed_at="2024-01-01T00:00:05Z", - # referenced_documents is None by default + # tool_calls and tool_results are None by default ) - transformed = transform_chat_message(entry) - - assistant_message = transformed["messages"][1] - - # Assert that the key is NOT present when the list is None - assert "referenced_documents" not in assistant_message - - def test_transform_message_with_referenced_documents(self) -> None: - """Test the transformation when referenced_documents are present.""" - docs = [ - ReferencedDocument(doc_title="Test Doc", doc_url="http://example.com") - ] # type: ignore - entry = CacheEntry( - query="query", - response="response", - provider="provider", - model="model", - started_at="2024-01-01T00:00:00Z", - completed_at="2024-01-01T00:00:05Z", - referenced_documents=docs, - ) - - transformed = transform_chat_message(entry) - assistant_message = transformed["messages"][1] - - assert "referenced_documents" in assistant_message - ref_docs = assistant_message["referenced_documents"] - assert len(ref_docs) == 1 - assert ref_docs[0]["doc_title"] == "Test Doc" - assert str(ref_docs[0]["doc_url"]) == "http://example.com/" - - def 
test_transform_message_with_empty_referenced_documents(self) -> None: - """Test the transformation when referenced_documents is an empty list.""" + turn = build_conversation_turn_from_cache_entry(entry) + + assert turn.tool_calls == [] + assert turn.tool_results == [] + assert turn.provider == "provider" + assert turn.model == "model" + assert len(turn.messages) == 2 + + def test_build_turn_with_tool_calls(self) -> None: + """Test building a turn when tool calls and results are present.""" + + tool_calls = [ + ToolCallSummary( + id="call_1", + name="test_tool", + args={"arg1": "value1"}, + type="function_call", + ) + ] + tool_results = [ + ToolResultSummary( + id="call_1", + status="success", + content="result", + type="function_call_output", + round=1, + ) + ] entry = CacheEntry( query="query", response="response", @@ -123,14 +82,18 @@ def test_transform_message_with_empty_referenced_documents(self) -> None: model="model", started_at="2024-01-01T00:00:00Z", completed_at="2024-01-01T00:00:05Z", - referenced_documents=[], # Explicitly empty + tool_calls=tool_calls, + tool_results=tool_results, ) - transformed = transform_chat_message(entry) - assistant_message = transformed["messages"][1] + turn = build_conversation_turn_from_cache_entry(entry) - assert "referenced_documents" in assistant_message - assert assistant_message["referenced_documents"] == [] + assert turn.provider == "provider" + assert turn.model == "model" + assert len(turn.tool_calls) == 1 + assert turn.tool_calls[0].name == "test_tool" + assert len(turn.tool_results) == 1 + assert turn.tool_results[0].status == "success" @pytest.fixture @@ -258,7 +221,9 @@ async def test_conversation_cache_not_configured( assert exc_info.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR detail = exc_info.value.detail assert isinstance(detail, dict) - assert "Conversation cache not configured" in detail["response"] + detail_dict = cast(dict[str, Any], detail) + response_text = detail_dict.get("response", "") + assert "Conversation cache not configured" in response_text @pytest.mark.asyncio async def test_successful_retrieval( @@ -269,9 +234,9 @@ async def test_successful_retrieval( mocker.patch("app.endpoints.conversations_v2.configuration", mock_configuration) timestamp_str = "2024-01-01T00:00:00Z" - timestamp_dt = datetime.fromisoformat( - timestamp_str.replace("Z", "+00:00") - ).replace(tzinfo=timezone.utc) + timestamp_dt = datetime.fromisoformat(timestamp_str).replace( + tzinfo=timezone.utc + ) timestamp = timestamp_dt.timestamp() mock_configuration.conversation_cache.list.return_value = [ @@ -335,20 +300,6 @@ async def test_with_skip_userid_check( "mock_user_id", True ) - @pytest.mark.asyncio - async def test_malformed_auth_object( - self, mocker: MockerFixture, mock_configuration: MockType - ) -> None: - """Test the endpoint with a malformed auth object.""" - mock_authorization_resolvers(mocker) - mocker.patch("app.endpoints.conversations_v2.configuration", mock_configuration) - - with pytest.raises(IndexError): - await get_conversations_list_endpoint_handler( - request=mocker.Mock(), - auth=(), # Malformed auth object - ) - class TestGetConversationEndpoint: """Test cases for the GET /conversations/{conversation_id} endpoint.""" @@ -421,7 +372,9 @@ async def test_conversation_cache_not_configured( assert exc_info.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR detail = exc_info.value.detail assert isinstance(detail, dict) - assert "Conversation cache not configured" in detail["response"] + detail_dict = 
cast(dict[str, Any], detail) + response_text = detail_dict.get("response", "") + assert "Conversation cache not configured" in response_text @pytest.mark.asyncio async def test_conversation_not_found( @@ -473,7 +426,7 @@ async def test_successful_retrieval( assert response is not None assert response.conversation_id == VALID_CONVERSATION_ID assert len(response.chat_history) == 1 - assert response.chat_history[0]["messages"][0]["content"] == "query" + assert response.chat_history[0].messages[0].content == "query" @pytest.mark.asyncio async def test_with_skip_userid_check( @@ -508,22 +461,6 @@ async def test_with_skip_userid_check( "mock_user_id", VALID_CONVERSATION_ID, True ) - @pytest.mark.asyncio - async def test_malformed_auth_object( - self, mocker: MockerFixture, mock_configuration: MockType - ) -> None: - """Test the endpoint with a malformed auth object.""" - mock_authorization_resolvers(mocker) - mocker.patch("app.endpoints.conversations_v2.configuration", mock_configuration) - mocker.patch("app.endpoints.conversations_v2.check_suid", return_value=True) - - with pytest.raises(IndexError): - await get_conversation_endpoint_handler( - request=mocker.Mock(), - conversation_id=VALID_CONVERSATION_ID, - auth=(), # Malformed auth object - ) - class TestDeleteConversationEndpoint: """Test cases for the DELETE /conversations/{conversation_id} endpoint.""" @@ -585,28 +522,9 @@ async def test_conversation_cache_not_configured( assert exc_info.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR detail = exc_info.value.detail assert isinstance(detail, dict) - assert "Conversation cache not configured" in detail["response"] - - @pytest.mark.asyncio - async def test_conversation_not_found( - self, mocker: MockerFixture, mock_configuration: MockType - ) -> None: - """Test the endpoint when conversation does not exist.""" - mock_authorization_resolvers(mocker) - mocker.patch("app.endpoints.conversations_v2.configuration", mock_configuration) - mocker.patch("app.endpoints.conversations_v2.check_suid", return_value=True) - mock_configuration.conversation_cache.delete.return_value = False - - response = await delete_conversation_endpoint_handler( - request=mocker.Mock(), - conversation_id=VALID_CONVERSATION_ID, - auth=MOCK_AUTH, - ) - - assert response is not None - assert response.conversation_id == VALID_CONVERSATION_ID - assert response.success is True - assert response.response == "Conversation cannot be deleted" + detail_dict = cast(dict[str, Any], detail) + response_text = detail_dict.get("response", "") + assert "Conversation cache not configured" in response_text @pytest.mark.asyncio async def test_successful_deletion( @@ -616,9 +534,6 @@ async def test_successful_deletion( mock_authorization_resolvers(mocker) mocker.patch("app.endpoints.conversations_v2.configuration", mock_configuration) mocker.patch("app.endpoints.conversations_v2.check_suid", return_value=True) - mock_configuration.conversation_cache.list.return_value = [ - mocker.Mock(conversation_id=VALID_CONVERSATION_ID) - ] mock_configuration.conversation_cache.delete.return_value = True response = await delete_conversation_endpoint_handler( @@ -636,13 +551,10 @@ async def test_successful_deletion( async def test_unsuccessful_deletion( self, mocker: MockerFixture, mock_configuration: MockType ) -> None: - """Test unsuccessful deletion of a conversation.""" + """Test unsuccessful deletion when delete returns False.""" mock_authorization_resolvers(mocker) mocker.patch("app.endpoints.conversations_v2.configuration", 
mock_configuration) mocker.patch("app.endpoints.conversations_v2.check_suid", return_value=True) - mock_configuration.conversation_cache.list.return_value = [ - mocker.Mock(conversation_id=VALID_CONVERSATION_ID) - ] mock_configuration.conversation_cache.delete.return_value = False response = await delete_conversation_endpoint_handler( @@ -674,9 +586,6 @@ async def test_with_skip_userid_check( mock_authorization_resolvers(mocker) mocker.patch("app.endpoints.conversations_v2.configuration", mock_configuration) mocker.patch("app.endpoints.conversations_v2.check_suid", return_value=True) - mock_configuration.conversation_cache.list.return_value = [ - mocker.Mock(conversation_id=VALID_CONVERSATION_ID) - ] mock_auth_with_skip = ("mock_user_id", "mock_username", True, "mock_token") await delete_conversation_endpoint_handler( @@ -689,22 +598,6 @@ async def test_with_skip_userid_check( "mock_user_id", VALID_CONVERSATION_ID, True ) - @pytest.mark.asyncio - async def test_malformed_auth_object( - self, mocker: MockerFixture, mock_configuration: MockType - ) -> None: - """Test the endpoint with a malformed auth object.""" - mock_authorization_resolvers(mocker) - mocker.patch("app.endpoints.conversations_v2.configuration", mock_configuration) - mocker.patch("app.endpoints.conversations_v2.check_suid", return_value=True) - - with pytest.raises(IndexError): - await delete_conversation_endpoint_handler( - request=mocker.Mock(), - conversation_id=VALID_CONVERSATION_ID, - auth=(), # Malformed auth object - ) - class TestUpdateConversationEndpoint: """Test cases for the PUT /conversations/{conversation_id} endpoint.""" @@ -777,7 +670,9 @@ async def test_conversation_cache_not_configured( assert exc_info.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR detail = exc_info.value.detail assert isinstance(detail, dict) - assert "Conversation cache not configured" in detail["response"] # type: ignore + detail_dict = cast(dict[str, Any], detail) + response_text = detail_dict.get("response", "") + assert "Conversation cache not configured" in response_text @pytest.mark.asyncio async def test_conversation_not_found( @@ -856,20 +751,3 @@ async def test_with_skip_userid_check( mock_configuration.conversation_cache.set_topic_summary.assert_called_once_with( "mock_user_id", VALID_CONVERSATION_ID, "New topic summary", True ) - - @pytest.mark.asyncio - async def test_malformed_auth_object( - self, mocker: MockerFixture, mock_configuration: MockType - ) -> None: - """Test the endpoint with a malformed auth object.""" - mock_authorization_resolvers(mocker) - mocker.patch("app.endpoints.conversations_v2.configuration", mock_configuration) - mocker.patch("app.endpoints.conversations_v2.check_suid", return_value=True) - update_request = ConversationUpdateRequest(topic_summary="New topic summary") - - with pytest.raises(IndexError): - await update_conversation_endpoint_handler( - conversation_id=VALID_CONVERSATION_ID, - update_request=update_request, - auth=(), # Malformed auth object - ) diff --git a/tests/unit/app/test_routers.py b/tests/unit/app/test_routers.py index 0da8c15f1..1a86ecf9b 100644 --- a/tests/unit/app/test_routers.py +++ b/tests/unit/app/test_routers.py @@ -8,7 +8,7 @@ from app.endpoints import ( conversations_v2, - conversations_v3, + conversations_v1, root, info, models, @@ -124,7 +124,7 @@ def test_include_routers() -> None: assert authorized.router in app.get_routers() # assert conversations.router in app.get_routers() assert conversations_v2.router in app.get_routers() - assert 
conversations_v3.router in app.get_routers() + assert conversations_v1.router in app.get_routers() assert metrics.router in app.get_routers() assert rlsapi_v1.router in app.get_routers() assert a2a.router in app.get_routers() @@ -163,7 +163,7 @@ def test_check_prefixes() -> None: assert app.get_router_prefix(authorized.router) == "" # assert app.get_router_prefix(conversations.router) == "/v1" assert app.get_router_prefix(conversations_v2.router) == "/v2" - assert app.get_router_prefix(conversations_v3.router) == "/v1" + assert app.get_router_prefix(conversations_v1.router) == "/v1" assert app.get_router_prefix(metrics.router) == "" assert app.get_router_prefix(rlsapi_v1.router) == "/v1" assert app.get_router_prefix(a2a.router) == "" diff --git a/tests/unit/models/responses/test_successful_responses.py b/tests/unit/models/responses/test_successful_responses.py index 117102e30..4a3456e5e 100644 --- a/tests/unit/models/responses/test_successful_responses.py +++ b/tests/unit/models/responses/test_successful_responses.py @@ -592,6 +592,10 @@ def test_constructor(self) -> None: {"content": "Hello", "type": "user"}, {"content": "Hi there!", "type": "assistant"}, ], + "tool_calls": [], + "tool_results": [], + "provider": "google", + "model": "gemini-2.0-flash-exp", "started_at": "2024-01-01T00:01:00Z", "completed_at": "2024-01-01T00:01:05Z", } @@ -602,7 +606,9 @@ def test_constructor(self) -> None: ) assert isinstance(response, AbstractSuccessfulResponse) assert response.conversation_id == "123e4567-e89b-12d3-a456-426614174000" - assert response.chat_history == chat_history + # Convert ConversationTurn objects to dicts for comparison + actual_history = [turn.model_dump() for turn in response.chat_history] + assert actual_history == chat_history def test_empty_chat_history(self) -> None: """Test ConversationResponse with empty chat_history.""" diff --git a/tests/unit/utils/test_conversations.py b/tests/unit/utils/test_conversations.py new file mode 100644 index 000000000..e4120f145 --- /dev/null +++ b/tests/unit/utils/test_conversations.py @@ -0,0 +1,722 @@ +"""Unit tests for conversation utility functions.""" + +from datetime import datetime, UTC +from typing import Any + +import pytest +from pytest_mock import MockerFixture + +from constants import DEFAULT_RAG_TOOL +from models.database.conversations import UserTurn +from utils.conversations import ( + _build_tool_call_summary_from_item, + _extract_text_from_content, + build_conversation_turns_from_items, +) +from utils.types import ToolCallSummary + +# Default conversation start time for tests +DEFAULT_CONVERSATION_START_TIME = datetime.fromisoformat( + "2024-01-01T00:00:00Z" +).replace(tzinfo=UTC) + + +@pytest.fixture(name="create_mock_user_turn") +def create_mock_user_turn_fixture(mocker: MockerFixture) -> Any: + """Factory fixture to create mock UserTurn objects. 
+ + Args: + mocker: Mocker fixture + + Returns: + Function that creates a mock UserTurn with specified attributes + """ + + def _create( + turn_number: int = 1, + started_at: str = "2024-01-01T00:01:00Z", + completed_at: str = "2024-01-01T00:01:05Z", + provider: str = "google", + model: str = "gemini-2.0-flash-exp", + ) -> Any: + mock_turn = mocker.Mock(spec=UserTurn) + mock_turn.turn_number = turn_number + mock_turn.started_at = datetime.fromisoformat(started_at).replace(tzinfo=UTC) + mock_turn.completed_at = datetime.fromisoformat(completed_at).replace( + tzinfo=UTC + ) + mock_turn.provider = provider + mock_turn.model = model + return mock_turn + + return _create + + +class TestExtractTextFromContent: + """Test cases for _extract_text_from_content function.""" + + def test_string_input(self) -> None: + """Test extracting text from string input.""" + content = "Simple text message" + result = _extract_text_from_content(content) + + assert result == "Simple text message" + + def test_composed_input(self) -> None: + """Test extracting text from composed (list) input.""" + + # Create simple objects with text and refusal attributes + class TextPart: # pylint: disable=too-few-public-methods + """Helper class for testing text extraction.""" + + def __init__(self, text: str) -> None: + self.text = text + + class RefusalPart: # pylint: disable=too-few-public-methods + """Helper class for testing refusal extraction.""" + + def __init__(self, refusal: str) -> None: + self.refusal = refusal + + # Create composed content with various types + content = [ + "String part", + TextPart("First part"), + RefusalPart("Refusal message"), + {"text": "Dict text"}, + {"refusal": "Dict refusal"}, + ] + + result = _extract_text_from_content(content) + + assert result == "String partFirst partRefusal messageDict textDict refusal" + + +class TestBuildToolCallSummaryFromItem: + """Test cases for _build_tool_call_summary_from_item function.""" + + def test_function_call_item(self, mocker: MockerFixture) -> None: + """Test parsing a function_call item.""" + mock_item = mocker.Mock() + mock_item.type = "function_call" + mock_item.call_id = "call_123" + mock_item.name = "test_function" + mock_item.arguments = '{"arg1": "value1"}' + + tool_call, tool_result = _build_tool_call_summary_from_item(mock_item) + + assert tool_call is not None + assert isinstance(tool_call, ToolCallSummary) + assert tool_call.id == "call_123" + assert tool_call.name == "test_function" + assert tool_call.type == "function_call" + assert tool_result is None + + def test_file_search_call_with_results(self, mocker: MockerFixture) -> None: + """Test parsing a file_search_call item with results.""" + mock_result = mocker.Mock() + mock_result.model_dump.return_value = {"file": "test.txt", "content": "test"} + + mock_item = mocker.Mock() + mock_item.type = "file_search_call" + mock_item.id = "file_search_123" + mock_item.queries = ["query1", "query2"] + mock_item.status = "success" + mock_item.results = [mock_result] + + tool_call, tool_result = _build_tool_call_summary_from_item(mock_item) + + assert tool_call is not None + assert tool_call.id == "file_search_123" + assert tool_call.name == DEFAULT_RAG_TOOL + assert tool_call.type == "file_search_call" + assert tool_call.args == {"queries": ["query1", "query2"]} + + assert tool_result is not None + assert tool_result.id == "file_search_123" + assert tool_result.status == "success" + assert tool_result.type == "file_search_call" + assert tool_result.round == 1 + assert "results" in tool_result.content 
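+
+    # A note on the contract the file_search_call assertions above pin down:
+    # the summary builder is expected to name the call DEFAULT_RAG_TOOL, copy
+    # the queries into args, and serialize item.results into the result
+    # content. A minimal sketch of that branch, assuming JSON encoding of the
+    # dumped results (the exact encoding is an implementation detail of
+    # utils.conversations._build_tool_call_summary_from_item):
+    #
+    #     if item.type == "file_search_call":
+    #         call = ToolCallSummary(
+    #             id=item.id,
+    #             name=DEFAULT_RAG_TOOL,
+    #             args={"queries": item.queries},
+    #             type="file_search_call",
+    #         )
+    #         content = (
+    #             json.dumps({"results": [r.model_dump() for r in item.results]})
+    #             if item.results
+    #             else ""
+    #         )
+    #         result = ToolResultSummary(
+    #             id=item.id,
+    #             status=item.status,
+    #             content=content,
+    #             type="file_search_call",
+    #             round=1,
+    #         )
+    #         return call, result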
+ + def test_file_search_call_without_results(self, mocker: MockerFixture) -> None: + """Test parsing a file_search_call item without results.""" + mock_item = mocker.Mock() + mock_item.type = "file_search_call" + mock_item.id = "file_search_123" + mock_item.queries = ["query1"] + mock_item.status = "success" + mock_item.results = None + + tool_call, tool_result = _build_tool_call_summary_from_item(mock_item) + + assert tool_call is not None + assert tool_result is not None + assert tool_result.content == "" + + def test_web_search_call(self, mocker: MockerFixture) -> None: + """Test parsing a web_search_call item.""" + mock_item = mocker.Mock() + mock_item.type = "web_search_call" + mock_item.id = "web_search_123" + mock_item.status = "success" + + tool_call, tool_result = _build_tool_call_summary_from_item(mock_item) + + assert tool_call is not None + assert tool_call.id == "web_search_123" + assert tool_call.name == "web_search" + assert tool_call.type == "web_search_call" + assert tool_call.args == {} + + assert tool_result is not None + assert tool_result.id == "web_search_123" + assert tool_result.status == "success" + assert tool_result.type == "web_search_call" + assert tool_result.content == "" + assert tool_result.round == 1 + + def test_mcp_call_with_error(self, mocker: MockerFixture) -> None: + """Test parsing an mcp_call item with error.""" + mock_item = mocker.Mock() + mock_item.type = "mcp_call" + mock_item.id = "mcp_123" + mock_item.name = "test_mcp_tool" + mock_item.arguments = '{"param": "value"}' + mock_item.server_label = "test_server" + mock_item.error = "Error occurred" + mock_item.output = None + + tool_call, tool_result = _build_tool_call_summary_from_item(mock_item) + + assert tool_call is not None + assert tool_call.id == "mcp_123" + assert tool_call.name == "test_mcp_tool" + assert tool_call.type == "mcp_call" + assert "server_label" in tool_call.args + assert tool_call.args["server_label"] == "test_server" + + assert tool_result is not None + assert tool_result.status == "failure" + assert tool_result.content == "Error occurred" + + def test_mcp_call_with_output(self, mocker: MockerFixture) -> None: + """Test parsing an mcp_call item with output.""" + mock_item = mocker.Mock() + mock_item.type = "mcp_call" + mock_item.id = "mcp_123" + mock_item.name = "test_mcp_tool" + mock_item.arguments = '{"param": "value"}' + mock_item.server_label = "test_server" + mock_item.error = None + mock_item.output = "Success output" + + _, tool_result = _build_tool_call_summary_from_item(mock_item) + + assert tool_result is not None + assert tool_result.status == "success" + assert tool_result.content == "Success output" + + def test_mcp_call_without_server_label(self, mocker: MockerFixture) -> None: + """Test parsing an mcp_call item without server_label.""" + mock_item = mocker.Mock() + mock_item.type = "mcp_call" + mock_item.id = "mcp_123" + mock_item.name = "test_mcp_tool" + mock_item.arguments = '{"param": "value"}' + mock_item.server_label = None + mock_item.error = None + mock_item.output = "output" + + tool_call, _ = _build_tool_call_summary_from_item(mock_item) + + assert tool_call is not None + assert "server_label" not in tool_call.args + + def test_mcp_list_tools(self, mocker: MockerFixture) -> None: + """Test parsing an mcp_list_tools item.""" + mock_tool = mocker.Mock() + mock_tool.name = "tool1" + mock_tool.description = "Description" + mock_tool.input_schema = {"type": "object"} + + mock_item = mocker.Mock() + mock_item.type = "mcp_list_tools" + mock_item.id = 
"list_tools_123" + mock_item.server_label = "test_server" + mock_item.tools = [mock_tool] + + tool_call, tool_result = _build_tool_call_summary_from_item(mock_item) + + assert tool_call is not None + assert tool_call.id == "list_tools_123" + assert tool_call.name == "mcp_list_tools" + assert tool_call.type == "mcp_list_tools" + assert tool_call.args == {"server_label": "test_server"} + + assert tool_result is not None + assert tool_result.status == "success" + assert "tools" in tool_result.content + assert "test_server" in tool_result.content + + def test_mcp_approval_request(self, mocker: MockerFixture) -> None: + """Test parsing an mcp_approval_request item.""" + mock_item = mocker.Mock() + mock_item.type = "mcp_approval_request" + mock_item.id = "approval_123" + mock_item.name = "approve_action" + mock_item.arguments = '{"action": "delete"}' + + tool_call, tool_result = _build_tool_call_summary_from_item(mock_item) + + assert tool_call is not None + assert tool_call.id == "approval_123" + assert tool_call.name == "approve_action" + assert tool_call.type == "tool_call" + assert tool_result is None + + def test_mcp_approval_response_approved(self, mocker: MockerFixture) -> None: + """Test parsing an mcp_approval_response item with approval.""" + mock_item = mocker.Mock() + mock_item.type = "mcp_approval_response" + mock_item.approval_request_id = "approval_123" + mock_item.approve = True + mock_item.reason = "Looks good" + + tool_call, tool_result = _build_tool_call_summary_from_item(mock_item) + + assert tool_call is None + assert tool_result is not None + assert tool_result.id == "approval_123" + assert tool_result.status == "success" + assert tool_result.type == "mcp_approval_response" + assert "reason" in tool_result.content + + def test_mcp_approval_response_denied(self, mocker: MockerFixture) -> None: + """Test parsing an mcp_approval_response item with denial.""" + mock_item = mocker.Mock() + mock_item.type = "mcp_approval_response" + mock_item.approval_request_id = "approval_123" + mock_item.approve = False + mock_item.reason = "Not allowed" + + _, tool_result = _build_tool_call_summary_from_item(mock_item) + + assert tool_result is not None + assert tool_result.status == "denied" + + def test_mcp_approval_response_without_reason(self, mocker: MockerFixture) -> None: + """Test parsing an mcp_approval_response item without reason.""" + mock_item = mocker.Mock() + mock_item.type = "mcp_approval_response" + mock_item.approval_request_id = "approval_123" + mock_item.approve = True + mock_item.reason = None + + _, tool_result = _build_tool_call_summary_from_item(mock_item) + + assert tool_result is not None + assert tool_result.content == "{}" + + def test_function_call_output(self, mocker: MockerFixture) -> None: + """Test parsing a function_call_output item.""" + mock_item = mocker.Mock() + mock_item.type = "function_call_output" + mock_item.call_id = "call_123" + mock_item.status = "success" + mock_item.output = "Function result" + + tool_call, tool_result = _build_tool_call_summary_from_item(mock_item) + + assert tool_call is None + assert tool_result is not None + assert tool_result.id == "call_123" + assert tool_result.status == "success" + assert tool_result.content == "Function result" + assert tool_result.type == "function_call_output" + assert tool_result.round == 1 + + def test_function_call_output_without_status(self, mocker: MockerFixture) -> None: + """Test parsing a function_call_output item without status.""" + mock_item = mocker.Mock() + mock_item.type = 
"function_call_output" + mock_item.call_id = "call_123" + mock_item.status = None + mock_item.output = "Function result" + + _, tool_result = _build_tool_call_summary_from_item(mock_item) + + assert tool_result is not None + assert tool_result.status == "success" # Defaults to "success" + + def test_unknown_item_type(self, mocker: MockerFixture) -> None: + """Test parsing an unknown item type.""" + mock_item = mocker.Mock() + mock_item.type = "unknown_type" + + tool_call, tool_result = _build_tool_call_summary_from_item(mock_item) + + assert tool_call is None + assert tool_result is None + + def test_item_without_type_attribute(self, mocker: MockerFixture) -> None: + """Test parsing an item without type attribute.""" + mock_item = mocker.Mock(spec=[]) + # Don't set type attribute + + tool_call, tool_result = _build_tool_call_summary_from_item(mock_item) + + assert tool_call is None + assert tool_result is None + + +class TestBuildConversationTurnsFromItems: + """Test cases for build_conversation_turns_from_items function.""" + + def test_empty_items(self) -> None: + """Test with empty items list.""" + result = build_conversation_turns_from_items( + [], [], DEFAULT_CONVERSATION_START_TIME + ) + + assert not result + + def test_single_turn_user_and_assistant( + self, + mocker: MockerFixture, + create_mock_user_turn: Any, + ) -> None: + """Test building a single turn with user and assistant messages.""" + mock_user_msg = mocker.Mock() + mock_user_msg.type = "message" + mock_user_msg.role = "user" + mock_user_msg.content = "Hello" + + mock_assistant_msg = mocker.Mock() + mock_assistant_msg.type = "message" + mock_assistant_msg.role = "assistant" + mock_assistant_msg.content = "Hi there!" + + items = [mock_user_msg, mock_assistant_msg] + turns_metadata = [create_mock_user_turn(turn_number=1)] + + result = build_conversation_turns_from_items( + items, turns_metadata, DEFAULT_CONVERSATION_START_TIME + ) + + assert len(result) == 1 + turn = result[0] + assert len(turn.messages) == 2 + assert turn.messages[0].type == "user" + assert turn.messages[0].content == "Hello" + assert turn.messages[1].type == "assistant" + assert turn.messages[1].content == "Hi there!" 
+ assert turn.tool_calls == [] + assert turn.tool_results == [] + + def test_multiple_turns( + self, mocker: MockerFixture, create_mock_user_turn: Any + ) -> None: + """Test building multiple turns.""" + items = [ + mocker.Mock(type="message", role="user", content="Question 1"), + mocker.Mock(type="message", role="assistant", content="Answer 1"), + mocker.Mock(type="message", role="user", content="Question 2"), + mocker.Mock(type="message", role="assistant", content="Answer 2"), + ] + turns_metadata = [ + create_mock_user_turn(turn_number=1), + create_mock_user_turn(turn_number=2), + ] + + result = build_conversation_turns_from_items( + items, turns_metadata, DEFAULT_CONVERSATION_START_TIME + ) + + assert len(result) == 2 + assert result[0].messages[0].content == "Question 1" + assert result[0].messages[1].content == "Answer 1" + assert result[1].messages[0].content == "Question 2" + assert result[1].messages[1].content == "Answer 2" + + def test_turn_with_tool_calls( + self, + mocker: MockerFixture, + create_mock_user_turn: Any, + ) -> None: + """Test building a turn with tool calls.""" + mock_function_call = mocker.Mock() + mock_function_call.type = "function_call" + mock_function_call.call_id = "call_1" + mock_function_call.name = "test_tool" + mock_function_call.arguments = '{"arg": "value"}' + + items = [ + mocker.Mock(type="message", role="user", content="Use tool"), + mock_function_call, + mocker.Mock(type="message", role="assistant", content="Done"), + ] + turns_metadata = [create_mock_user_turn(turn_number=1)] + + result = build_conversation_turns_from_items( + items, turns_metadata, DEFAULT_CONVERSATION_START_TIME + ) + + assert len(result) == 1 + assert len(result[0].tool_calls) == 1 + assert result[0].tool_calls[0].name == "test_tool" + + def test_turn_with_tool_results( + self, + mocker: MockerFixture, + create_mock_user_turn: Any, + ) -> None: + """Test building a turn with tool results.""" + mock_function_output = mocker.Mock() + mock_function_output.type = "function_call_output" + mock_function_output.call_id = "call_1" + mock_function_output.status = "success" + mock_function_output.output = "Result" + + items = [ + mocker.Mock(type="message", role="user", content="Use tool"), + mock_function_output, + mocker.Mock(type="message", role="assistant", content="Done"), + ] + turns_metadata = [create_mock_user_turn(turn_number=1)] + + result = build_conversation_turns_from_items( + items, turns_metadata, DEFAULT_CONVERSATION_START_TIME + ) + + assert len(result) == 1 + assert len(result[0].tool_results) == 1 + assert result[0].tool_results[0].status == "success" + + def test_turn_with_both_tool_calls_and_results( + self, + mocker: MockerFixture, + create_mock_user_turn: Any, + ) -> None: + """Test building a turn with both tool calls and results.""" + mock_function_call = mocker.Mock() + mock_function_call.type = "function_call" + mock_function_call.call_id = "call_1" + mock_function_call.name = "test_tool" + mock_function_call.arguments = "{}" + + mock_function_output = mocker.Mock() + mock_function_output.type = "function_call_output" + mock_function_output.call_id = "call_1" + mock_function_output.status = "success" + mock_function_output.output = "Result" + + items = [ + mocker.Mock(type="message", role="user", content="Use tool"), + mock_function_call, + mock_function_output, + mocker.Mock(type="message", role="assistant", content="Done"), + ] + turns_metadata = [create_mock_user_turn(turn_number=1)] + + result = build_conversation_turns_from_items( + items, turns_metadata, 
DEFAULT_CONVERSATION_START_TIME + ) + + assert len(result) == 1 + assert len(result[0].tool_calls) == 1 + assert len(result[0].tool_results) == 1 + + def test_turn_with_file_search_tool( + self, + mocker: MockerFixture, + create_mock_user_turn: Any, + ) -> None: + """Test building a turn with file_search_call tool.""" + mock_file_search = mocker.Mock() + mock_file_search.type = "file_search_call" + mock_file_search.id = "file_1" + mock_file_search.queries = ["query1"] + mock_file_search.status = "success" + mock_file_search.results = None + + items = [ + mocker.Mock(type="message", role="user", content="Search files"), + mock_file_search, + mocker.Mock(type="message", role="assistant", content="Found files"), + ] + turns_metadata = [create_mock_user_turn(turn_number=1)] + + result = build_conversation_turns_from_items( + items, turns_metadata, DEFAULT_CONVERSATION_START_TIME + ) + + assert len(result) == 1 + assert len(result[0].tool_calls) == 1 + assert len(result[0].tool_results) == 1 + assert result[0].tool_calls[0].name == DEFAULT_RAG_TOOL + + def test_turn_with_multiple_assistant_messages( + self, + mocker: MockerFixture, + create_mock_user_turn: Any, + ) -> None: + """Test building a turn with multiple assistant messages.""" + items = [ + mocker.Mock(type="message", role="user", content="Question"), + mocker.Mock(type="message", role="assistant", content="Part 1"), + mocker.Mock(type="message", role="assistant", content="Part 2"), + ] + turns_metadata = [create_mock_user_turn(turn_number=1)] + + result = build_conversation_turns_from_items( + items, turns_metadata, DEFAULT_CONVERSATION_START_TIME + ) + + assert len(result) == 1 + assert len(result[0].messages) == 3 + assert result[0].messages[0].type == "user" + assert result[0].messages[1].type == "assistant" + assert result[0].messages[2].type == "assistant" + + def test_turn_metadata_used_correctly( + self, + mocker: MockerFixture, + create_mock_user_turn: Any, + ) -> None: + """Test that turn metadata (provider, model, timestamps) is used correctly.""" + items = [ + mocker.Mock(type="message", role="user", content="Test"), + mocker.Mock(type="message", role="assistant", content="Response"), + ] + turns_metadata = [ + create_mock_user_turn( + turn_number=1, + provider="openai", + model="gpt-4", + started_at="2024-01-01T10:00:00Z", + completed_at="2024-01-01T10:00:05Z", + ) + ] + + result = build_conversation_turns_from_items( + items, turns_metadata, DEFAULT_CONVERSATION_START_TIME + ) + + assert len(result) == 1 + turn = result[0] + assert turn.provider == "openai" + assert turn.model == "gpt-4" + assert turn.started_at == "2024-01-01T10:00:00Z" + assert turn.completed_at == "2024-01-01T10:00:05Z" + + def test_turn_with_only_tool_items_no_messages( + self, + mocker: MockerFixture, + create_mock_user_turn: Any, + ) -> None: + """Test building a turn with only tool items (no messages).""" + mock_function_call = mocker.Mock() + mock_function_call.type = "function_call" + mock_function_call.call_id = "call_1" + mock_function_call.name = "test_tool" + mock_function_call.arguments = "{}" + + items = [mock_function_call] + turns_metadata = [create_mock_user_turn(turn_number=1)] + + result = build_conversation_turns_from_items( + items, turns_metadata, DEFAULT_CONVERSATION_START_TIME + ) + + # Should still create a turn if there are tool calls/results + assert len(result) == 1 + assert len(result[0].messages) == 0 + assert len(result[0].tool_calls) == 1 + + def test_multiple_turns_with_tools( + self, + mocker: MockerFixture, + 
create_mock_user_turn: Any, + ) -> None: + """Test building multiple turns where some have tools.""" + mock_function_call = mocker.Mock() + mock_function_call.type = "function_call" + mock_function_call.call_id = "call_1" + mock_function_call.name = "test_tool" + mock_function_call.arguments = "{}" + + items = [ + mocker.Mock(type="message", role="user", content="Question 1"), + mocker.Mock(type="message", role="assistant", content="Answer 1"), + mocker.Mock(type="message", role="user", content="Question 2"), + mock_function_call, + mocker.Mock(type="message", role="assistant", content="Answer 2"), + ] + turns_metadata = [ + create_mock_user_turn(turn_number=1), + create_mock_user_turn(turn_number=2), + ] + + result = build_conversation_turns_from_items( + items, turns_metadata, DEFAULT_CONVERSATION_START_TIME + ) + + assert len(result) == 2 + assert len(result[0].tool_calls) == 0 + assert len(result[1].tool_calls) == 1 + + def test_turn_indexing_with_metadata( + self, + mocker: MockerFixture, + create_mock_user_turn: Any, + ) -> None: + """Test that turn metadata is correctly indexed by turn number.""" + items = [ + mocker.Mock(type="message", role="user", content="Q1"), + mocker.Mock(type="message", role="assistant", content="A1"), + mocker.Mock(type="message", role="user", content="Q2"), + mocker.Mock(type="message", role="assistant", content="A2"), + mocker.Mock(type="message", role="user", content="Q3"), + mocker.Mock(type="message", role="assistant", content="A3"), + ] + turns_metadata = [ + create_mock_user_turn(turn_number=1, provider="provider1"), + create_mock_user_turn(turn_number=2, provider="provider2"), + create_mock_user_turn(turn_number=3, provider="provider3"), + ] + + result = build_conversation_turns_from_items( + items, turns_metadata, DEFAULT_CONVERSATION_START_TIME + ) + + assert len(result) == 3 + assert result[0].provider == "provider1" + assert result[1].provider == "provider2" + assert result[2].provider == "provider3" + + def test_legacy_conversation_without_metadata(self, mocker: MockerFixture) -> None: + """Test building turns for legacy conversation without stored turn metadata.""" + # Legacy conversations have items but no turns_metadata + items = [ + mocker.Mock(type="message", role="user", content="Question"), + mocker.Mock(type="message", role="assistant", content="Answer"), + ] + turns_metadata: list[UserTurn] = [] # Empty metadata for legacy conversation + conversation_start_time = datetime.fromisoformat( + "2024-01-01T10:00:00Z" + ).replace(tzinfo=UTC) + + result = build_conversation_turns_from_items( + items, turns_metadata, conversation_start_time + ) + + assert len(result) == 1 + turn = result[0] + assert len(turn.messages) == 2 + # Legacy conversations should use dummy metadata with N/A values + assert turn.provider == "N/A" + assert turn.model == "N/A" + # Timestamps should match conversation start time + assert turn.started_at == "2024-01-01T10:00:00Z" + assert turn.completed_at == "2024-01-01T10:00:00Z"
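+
+# The legacy test above fixes the assumed fallback behaviour of
+# build_conversation_turns_from_items: when no UserTurn metadata rows exist,
+# provider and model degrade to "N/A" and both timestamps fall back to the
+# conversation start time. A minimal sketch of that branch for the
+# single-turn case exercised here, assuming a ConversationTurn model with
+# the fields asserted in these tests (the real implementation lives in
+# src/utils/conversations.py):
+#
+#     if not turns_metadata:
+#         fallback = conversation_start_time.isoformat().replace("+00:00", "Z")
+#         return [
+#             ConversationTurn(
+#                 messages=messages,
+#                 tool_calls=tool_calls,
+#                 tool_results=tool_results,
+#                 provider="N/A",
+#                 model="N/A",
+#                 started_at=fallback,
+#                 completed_at=fallback,
+#             )
+#         ]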