diff --git a/.claude/settings.local.json b/.claude/settings.local.json new file mode 100644 index 000000000..1d3fa4e21 --- /dev/null +++ b/.claude/settings.local.json @@ -0,0 +1,13 @@ +{ + "permissions": { + "allow": [ + "Bash(git add:*)", + "Bash(git push:*)", + "Bash(git pull:*)", + "Bash(echo:*)", + "Bash(uv run -m pytest:*)", + "Bash(git commit:*)", + "Bash(git merge:*)" + ] + } +} diff --git a/.github/workflows/claude-code-review.yml b/.github/workflows/claude-code-review.yml new file mode 100644 index 000000000..4f6145beb --- /dev/null +++ b/.github/workflows/claude-code-review.yml @@ -0,0 +1,44 @@ +name: Claude Code Review + +on: + pull_request: + types: [opened, synchronize, ready_for_review, reopened] + # Optional: Only run on specific file changes + # paths: + # - "src/**/*.ts" + # - "src/**/*.tsx" + # - "src/**/*.js" + # - "src/**/*.jsx" + +jobs: + claude-review: + # Optional: Filter by PR author + # if: | + # github.event.pull_request.user.login == 'external-contributor' || + # github.event.pull_request.user.login == 'new-developer' || + # github.event.pull_request.author_association == 'FIRST_TIME_CONTRIBUTOR' + + runs-on: ubuntu-latest + permissions: + contents: read + pull-requests: read + issues: read + id-token: write + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + with: + fetch-depth: 1 + + - name: Run Claude Code Review + id: claude-review + uses: anthropics/claude-code-action@v1 + with: + anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }} + plugin_marketplaces: 'https://github.com/anthropics/claude-code.git' + plugins: 'code-review@claude-code-plugins' + prompt: '/code-review:code-review ${{ github.repository }}/pull/${{ github.event.pull_request.number }}' + # See https://github.com/anthropics/claude-code-action/blob/main/docs/usage.md + # or https://code.claude.com/docs/en/cli-reference for available options + diff --git a/.github/workflows/claude.yml b/.github/workflows/claude.yml new file mode 100644 index 
000000000..79fe05647 --- /dev/null +++ b/.github/workflows/claude.yml @@ -0,0 +1,50 @@ +name: Claude Code + +on: + issue_comment: + types: [created] + pull_request_review_comment: + types: [created] + issues: + types: [opened, assigned] + pull_request_review: + types: [submitted] + +jobs: + claude: + if: | + (github.event_name == 'issue_comment' && contains(github.event.comment.body, '@claude')) || + (github.event_name == 'pull_request_review_comment' && contains(github.event.comment.body, '@claude')) || + (github.event_name == 'pull_request_review' && contains(github.event.review.body, '@claude')) || + (github.event_name == 'issues' && (contains(github.event.issue.body, '@claude') || contains(github.event.issue.title, '@claude'))) + runs-on: ubuntu-latest + permissions: + contents: read + pull-requests: read + issues: read + id-token: write + actions: read # Required for Claude to read CI results on PRs + steps: + - name: Checkout repository + uses: actions/checkout@v4 + with: + fetch-depth: 1 + + - name: Run Claude Code + id: claude + uses: anthropics/claude-code-action@v1 + with: + anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }} + + # This is an optional setting that allows Claude to read CI results on PRs + additional_permissions: | + actions: read + + # Optional: Give a custom prompt to Claude. If this is not specified, Claude will perform the instructions specified in the comment that tagged it. + # prompt: 'Update the pull request description to include a summary of changes.' 
+ + # Optional: Add claude_args to customize behavior and configuration + # See https://github.com/anthropics/claude-code-action/blob/main/docs/usage.md + # or https://code.claude.com/docs/en/cli-reference for available options + # claude_args: '--allowed-tools Bash(gh pr:*)' + diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 000000000..7162cce15 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,203 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. + +## Project Overview + +This is a Retrieval-Augmented Generation (RAG) chatbot system that answers questions about DeepLearning.AI course materials. It uses ChromaDB for vector storage, Anthropic's Claude API with tool calling, and provides a web interface for conversational queries. + +## Development Commands + +### Setup +```bash +# Install dependencies (uses uv package manager) +uv sync + +# Set up environment variables +cp .env.example .env +# Edit .env and add your ANTHROPIC_API_KEY +``` + +### Running the Application +```bash +# Quick start (recommended) +./run.sh + +# Manual start (from project root) +cd backend && uv run uvicorn app:app --reload --port 8000 + +# Access points +# - Web UI: http://localhost:8000 +# - API docs: http://localhost:8000/docs +``` + +### Adding Course Documents +Place text files in the `docs/` folder. Files are automatically loaded on server startup. See "Document Format Requirements" below. + +## Architecture + +### RAG Pipeline Flow + +The system implements a **tool-based RAG architecture** where Claude decides when to search: + +``` +User Query → FastAPI → RAGSystem → AIGenerator → Claude API + ↓ + (Claude calls tool) + ↓ + SearchTool → VectorStore → ChromaDB + ↓ + (search results returned) + ↓ + Claude synthesizes response + ↓ + SessionManager stores history + ↓ + Return answer + sources +``` + +**Key architectural decision**: Claude has search as a *callable tool*, not an always-on feature. 
The system prompt instructs Claude to call `search_course_content` when needed, making searches contextual rather than automatic. + +### Dual Collection Strategy + +ChromaDB uses **two separate collections** with different purposes: + +1. **`course_catalog`** (vector_store.py:51) + - Purpose: Fuzzy matching of course names for search filtering + - Documents: Course titles only + - Metadata: Full course info (instructor, links, lesson metadata) + - IDs: Course title (serves as unique identifier) + - Usage: When user specifies a course name, semantic search finds the best match + +2. **`course_content`** (vector_store.py:52) + - Purpose: Actual semantic search of course material + - Documents: Text chunks with enriched context + - Metadata: `{course_title, lesson_number, chunk_index}` + - IDs: `"{course_title_snake_case}_{chunk_index}"` + - Usage: Primary search collection for answering queries + +This separation enables fuzzy course name matching (e.g., "MCP" → "MCP: Build Rich-Context AI Apps") before searching content. + +### Component Relationships + +**rag_system.py** is the orchestration layer that: +- Coordinates all components (VectorStore, AIGenerator, SearchTool, SessionManager) +- Manages document ingestion and deduplication +- Handles query flow from input to response + +**ai_generator.py** handles Claude API interactions: +- Builds API requests with system prompt, history, and tool definitions +- Processes tool calls from Claude +- Extracts responses and sources from Claude's output +- Uses temperature=0 for deterministic responses + +**session_manager.py** maintains conversation state: +- Thread-safe session storage with dict-based in-memory storage +- Automatically trims history to last `MAX_HISTORY` exchanges (default: 2) +- Each session tracks conversation context for multi-turn queries + +### Text Chunking Strategy + +**document_processor.py:25-91** implements sentence-aware chunking: + +1. 
**Sentence splitting** using regex that handles abbreviations (Mr., Dr., etc.) +2. **Chunk building** up to 800 characters per chunk +3. **Overlap calculation** - 100 characters shared between consecutive chunks by counting backwards from chunk end +4. **Context enrichment** - First chunk of each lesson prefixed with `"Lesson N content: ..."`, last lesson chunks include course title + +This preserves semantic boundaries and context across chunk boundaries. + +## Document Format Requirements + +Course documents must follow this structure: + +``` +Course Title: [title] +Course Link: [url] +Course Instructor: [name] + +Lesson 0: [title] +Lesson Link: [url] +[content...] + +Lesson 1: [title] +Lesson Link: [url] +[content...] +``` + +**Processing behavior**: +- Lines 1-3: Metadata extraction with regex matching +- Remaining lines: Parsed for `^Lesson\s+(\d+):\s*(.+)$` markers +- Content between lesson markers becomes lesson content +- Lesson links (optional) must appear immediately after lesson headers +- If no lesson markers found, entire file treated as single document + +## Configuration + +All configuration in **backend/config.py** as a dataclass: + +- `CHUNK_SIZE`: 800 characters (sentence-aware, not hard cutoff) +- `CHUNK_OVERLAP`: 100 characters between chunks +- `MAX_RESULTS`: 5 search results per query +- `MAX_HISTORY`: 2 conversation exchanges retained +- `EMBEDDING_MODEL`: "all-MiniLM-L6-v2" (384-dimensional embeddings) +- `ANTHROPIC_MODEL`: "claude-sonnet-4-20250514" +- `CHROMA_PATH`: "./chroma_db" (persistent vector storage) + +## Important Patterns + +### Document Deduplication +**rag_system.py:76** checks existing course titles before processing. If a course with the same title already exists in the vector store, it's skipped. To reload a course, clear the vector store first. 
+ +### Tool Definition +**search_tools.py** defines the `search_course_content` tool with three parameters: +- `query` (required): What to search for +- `course_name` (optional): Fuzzy-matched against course_catalog +- `lesson_number` (optional): Filter to specific lesson + +The system prompt instructs Claude to use this tool strategically, not for every query. + +### Search Filtering +**vector_store.py:118-133** builds ChromaDB filters: +- Both course + lesson: `{"$and": [{"course_title": "..."}, {"lesson_number": N}]}` +- Course only: `{"course_title": "..."}` +- Lesson only: `{"lesson_number": N}` +- Neither: No filter (search all content) + +### Session Management +Sessions are created implicitly if no `session_id` is provided. Frontend passes `session_id` back to maintain conversation context. Sessions are stored in-memory (lost on restart). + +## Key Files + +- **app.py**: FastAPI application, startup document loading, API endpoints +- **rag_system.py**: Main orchestration, coordinates all components +- **vector_store.py**: ChromaDB wrapper, dual collection management, search logic +- **ai_generator.py**: Claude API integration, tool call handling +- **document_processor.py**: Metadata extraction, chunking algorithm +- **search_tools.py**: Tool definitions for Claude function calling +- **session_manager.py**: Conversation history management +- **config.py**: Centralized configuration +- **models.py**: Pydantic data models (Course, Lesson, CourseChunk) + +## Frontend + +Vanilla JavaScript application (frontend/) with no framework dependencies: +- **index.html**: Chat UI structure +- **script.js**: API communication, message handling +- **style.css**: Responsive styling + +Frontend communicates with backend via `/api/query` POST endpoint, receives responses with `{answer, sources, session_id}`. + +## Extending the System + +### Adding New Course Sources +Place files in `docs/` folder matching the required format. Supported extensions: `.txt`, `.pdf`, `.docx`. 
Server automatically loads on startup. + +### Modifying Chunking Behavior +Edit `CHUNK_SIZE` and `CHUNK_OVERLAP` in config.py. Larger chunks provide more context but reduce granularity. More overlap improves context preservation but increases storage. + +### Changing Search Results +Modify `MAX_RESULTS` in config.py to return more/fewer chunks per search. More results give Claude more context but increase token usage. + +### Adjusting Conversation Memory +Change `MAX_HISTORY` in config.py. Higher values retain more context but increase token costs. Each exchange = 2 messages (user + assistant). diff --git a/backend-tool-refactor.md b/backend-tool-refactor.md new file mode 100644 index 000000000..de23ae5c7 --- /dev/null +++ b/backend-tool-refactor.md @@ -0,0 +1,28 @@ +Refactor @backend/ai_generator.py to support sequential tool calling where Claude can make up to 2 tool calls in separate API rounds. + +Current behavior: +- Claude makes 1 tool call → tools are removed from API params → final response +- If Claude wants another tool call after seeing results, it can't (gets empty response) + +Desired behavior: +- Each tool call should be a separate API request where Claude can reason about previous results +- Support for complex queries requiring multiple searches for comparisons, multi-part questions, or when information from different courses/lessons is needed + +Example flow: +1. User: "Search for a course that discusses the same topic as lesson 4 of course X" +2. Claude: get course outline for course X → gets title of lesson 4 +3. Claude: uses the title to search for a course that discusses the same topic → returns course information +4. 
Claude: provides complete answer + +Requirements: +- Maximum 2 sequential rounds per user query +- Terminate when: (a) 2 rounds completed, (b) Claude's response has no tool_use blocks, or (c) tool call fails +- Preserve conversation context between rounds +- Handle tool execution errors gracefully + +Notes: +- Update the system prompt in @backend/ai_generator.py +- Update the test @backend/tests/test_ai_generator.py +- Write tests that verify the external behavior (API calls made, tools executed, results returned) rather than internal state details. + +Use two parallel subagents to brainstorm possible plans. Do not implement any code. diff --git a/backend/ai_generator.py b/backend/ai_generator.py index 0363ca90c..88a4b59cc 100644 --- a/backend/ai_generator.py +++ b/backend/ai_generator.py @@ -1,25 +1,43 @@ +from typing import Any, Dict, List, Optional + import anthropic -from typing import List, Optional, Dict, Any + class AIGenerator: """Handles interactions with Anthropic's Claude API for generating responses""" - + + MAX_TOOL_ROUNDS = 2 + DIRECT_RETURN_TOOLS = frozenset({"get_course_outline"}) + # Static system prompt to avoid rebuilding on each call - SYSTEM_PROMPT = """ You are an AI assistant specialized in course materials and educational content with access to a comprehensive search tool for course information. + SYSTEM_PROMPT = """ You are an AI assistant specialized in course materials and educational content with access to comprehensive tools for course information. 
+ +Available Tools: +- **search_course_content**: Search within course materials for specific content +- **get_course_outline**: Get complete course outline with all lessons + +Tool Usage Guidelines: +- Use search_course_content for detailed questions about specific topics or lessons +- Use get_course_outline for questions about course structure, lesson lists, or "what's in this course" +- **You can make up to 2 rounds of tool calls to gather comprehensive information** + - Round 1: Initial search to gather relevant information + - Round 2: Refine or search additional context (different course, narrower lesson, related term) + - Most queries need only 1 tool call. Use a second only when the first result is insufficient. +- Synthesize tool results into accurate, fact-based responses +- If tools yield no results, state this clearly without offering alternatives -Search Tool Usage: -- Use the search tool **only** for questions about specific course content or detailed educational materials -- **One search per query maximum** -- Synthesize search results into accurate, fact-based responses -- If search yields no results, state this clearly without offering alternatives +Course Outline Responses: +When using get_course_outline: +- Return the tool output EXACTLY as formatted - do not add summaries, context, or additional information +- Present the complete structured list without modification Response Protocol: -- **General knowledge questions**: Answer using existing knowledge without searching -- **Course-specific questions**: Search first, then answer +- **General knowledge questions**: Answer using existing knowledge without tools +- **Course outline questions**: Use get_course_outline first +- **Course-specific content questions**: Use search_course_content first, then synthesize - **No meta-commentary**: - - Provide direct answers only — no reasoning process, search explanations, or question-type analysis - - Do not mention "based on the search results" - + - 
Provide direct answers only — no reasoning process, tool explanations, or question-type analysis + - Do not mention "based on the tool results" All responses must be: 1. **Brief, Concise and focused** - Get to the point quickly @@ -28,108 +46,171 @@ class AIGenerator: 4. **Example-supported** - Include relevant examples when they aid understanding Provide only the direct answer to what was asked. """ - + def __init__(self, api_key: str, model: str): self.client = anthropic.Anthropic(api_key=api_key) self.model = model - + # Pre-build base API parameters - self.base_params = { - "model": self.model, - "temperature": 0, - "max_tokens": 800 - } - - def generate_response(self, query: str, - conversation_history: Optional[str] = None, - tools: Optional[List] = None, - tool_manager=None) -> str: + self.base_params = {"model": self.model, "temperature": 0, "max_tokens": 800} + + def _call_api(self, **params): + """Make an Anthropic API call with standardized error handling.""" + try: + return self.client.messages.create(**params) + except anthropic.AuthenticationError as e: + raise RuntimeError(f"Anthropic API authentication failed: {e}") from e + except anthropic.APIError as e: + raise RuntimeError(f"Anthropic API error: {e}") from e + + def generate_response( + self, + query: str, + conversation_history: Optional[str] = None, + tools: Optional[List] = None, + tool_manager=None, + ) -> str: """ Generate AI response with optional tool usage and conversation context. - + Supports up to MAX_TOOL_ROUNDS sequential rounds of tool calling. 
+ Args: query: The user's question or request conversation_history: Previous messages for context tools: Available tools the AI can use tool_manager: Manager to execute tools - + Returns: Generated response as string """ - + # Build system content efficiently - avoid string ops when possible system_content = ( f"{self.SYSTEM_PROMPT}\n\nPrevious conversation:\n{conversation_history}" - if conversation_history + if conversation_history else self.SYSTEM_PROMPT ) - - # Prepare API call parameters efficiently - api_params = { + + # Start with initial messages + messages = [{"role": "user", "content": query}] + + # Execute up to MAX_TOOL_ROUNDS rounds of tool calling + for round_num in range(self.MAX_TOOL_ROUNDS): + # Prepare API call parameters + api_params = { + **self.base_params, + "messages": messages, + "system": system_content, + } + + # Add tools if available + if tools: + api_params["tools"] = tools + api_params["tool_choice"] = {"type": "auto"} + + response = self._call_api(**api_params) + + # Handle tool execution if needed + if response.stop_reason == "tool_use" and tool_manager: + messages, should_continue, direct_result = self._handle_tool_execution( + response, messages, tool_manager + ) + if direct_result is not None: + return direct_result + if not should_continue: + break + else: + # No tool use, return direct response + return self._extract_text(response) + + # After max rounds, make final call without tools to get response + final_params = { **self.base_params, - "messages": [{"role": "user", "content": query}], - "system": system_content + "messages": messages, + "system": system_content, } - - # Add tools if available - if tools: - api_params["tools"] = tools - api_params["tool_choice"] = {"type": "auto"} - - # Get response from Claude - response = self.client.messages.create(**api_params) - - # Handle tool execution if needed - if response.stop_reason == "tool_use" and tool_manager: - return self._handle_tool_execution(response, api_params, 
tool_manager) - - # Return direct response - return response.content[0].text - - def _handle_tool_execution(self, initial_response, base_params: Dict[str, Any], tool_manager): + + final_response = self._call_api(**final_params) + return self._extract_text(final_response) + + @staticmethod + def _extract_text(response) -> str: + """Safely extract text from an API response, handling empty content.""" + if not response.content: + return "I'm sorry, I wasn't able to generate a response. Please try again." + for block in response.content: + if hasattr(block, "text"): + return block.text + return "I'm sorry, I wasn't able to generate a response. Please try again." + + def _handle_tool_execution(self, initial_response, messages: List, tool_manager): """ - Handle execution of tool calls and get follow-up response. - + Handle execution of tool calls and update message history. + + Executes ALL tool calls before deciding flow control. This ensures the + Anthropic API receives tool_result blocks for every tool_use block, even + if some tools fail. 
+ Args: initial_response: The response containing tool use requests - base_params: Base API parameters + messages: Current message history tool_manager: Manager to execute tools - + Returns: - Final response text after tool execution + Tuple of (updated_messages, should_continue, direct_result) + direct_result is non-None when the tool output should be returned as-is """ - # Start with existing messages - messages = base_params["messages"].copy() - # Add AI's tool use response messages.append({"role": "assistant", "content": initial_response.content}) - - # Execute all tool calls and collect results + + # Execute ALL tool calls and collect results tool_results = [] + direct_return_result = None + has_error = False + for content_block in initial_response.content: - if content_block.type == "tool_use": + if content_block.type != "tool_use": + continue + + try: tool_result = tool_manager.execute_tool( - content_block.name, - **content_block.input + content_block.name, **content_block.input + ) + + tool_results.append( + { + "type": "tool_result", + "tool_use_id": content_block.id, + "content": tool_result, + } ) - - tool_results.append({ - "type": "tool_result", - "tool_use_id": content_block.id, - "content": tool_result - }) - - # Add tool results as single message + + # Mark outline results for direct return (but keep executing remaining tools) + if content_block.name in self.DIRECT_RETURN_TOOLS: + direct_return_result = tool_result + + except Exception as e: + has_error = True + tool_results.append( + { + "type": "tool_result", + "tool_use_id": content_block.id, + "content": f"Error: Tool execution failed - {str(e)}", + "is_error": True, + } + ) + + # Add all tool results as single message if tool_results: messages.append({"role": "user", "content": tool_results}) - - # Prepare final API call without tools - final_params = { - **self.base_params, - "messages": messages, - "system": base_params["system"] - } - - # Get final response - final_response = 
self.client.messages.create(**final_params) - return final_response.content[0].text \ No newline at end of file + + # Direct return takes priority (e.g. course outline) + if direct_return_result is not None: + return messages, False, direct_return_result + + # Stop rounds if any tool failed + if has_error: + return messages, False, None + + # Continue with next round + return messages, True, None diff --git a/backend/app.py b/backend/app.py index 5a69d741d..352ee097c 100644 --- a/backend/app.py +++ b/backend/app.py @@ -1,4 +1,5 @@ import warnings + warnings.filterwarnings("ignore", message="resource_tracker: There appear to be.*") from fastapi import FastAPI, HTTPException @@ -6,7 +7,7 @@ from fastapi.staticfiles import StaticFiles from fastapi.middleware.trustedhost import TrustedHostMiddleware from pydantic import BaseModel -from typing import List, Optional +from typing import List, Optional, Dict, Union import os from config import config @@ -16,10 +17,7 @@ app = FastAPI(title="Course Materials RAG System", root_path="") # Add trusted host middleware for proxy -app.add_middleware( - TrustedHostMiddleware, - allowed_hosts=["*"] -) +app.add_middleware(TrustedHostMiddleware, allowed_hosts=["*"]) # Enable CORS with proper settings for proxy app.add_middleware( @@ -34,25 +32,39 @@ # Initialize RAG system rag_system = RAGSystem(config) + # Pydantic models for request/response class QueryRequest(BaseModel): """Request model for course queries""" + query: str session_id: Optional[str] = None + class QueryResponse(BaseModel): """Response model for course queries""" + answer: str - sources: List[str] + sources: List[Union[str, Dict[str, str]]] # Support both strings and dicts session_id: str + class CourseStats(BaseModel): """Response model for course statistics""" + total_courses: int course_titles: List[str] + +class ClearSessionRequest(BaseModel): + """Request model for clearing a session""" + + session_id: str + + # API Endpoints + @app.post("/api/query", 
response_model=QueryResponse) async def query_documents(request: QueryRequest): """Process a query and return response with sources""" @@ -61,18 +73,15 @@ async def query_documents(request: QueryRequest): session_id = request.session_id if not session_id: session_id = rag_system.session_manager.create_session() - + # Process query using RAG system - answer, sources = rag_system.query(request.query, session_id) - - return QueryResponse( - answer=answer, - sources=sources, - session_id=session_id - ) + answer, sources, source_links = rag_system.query(request.query, session_id) + + return QueryResponse(answer=answer, sources=sources, session_id=session_id) except Exception as e: raise HTTPException(status_code=500, detail=str(e)) + @app.get("/api/courses", response_model=CourseStats) async def get_course_stats(): """Get course analytics and statistics""" @@ -80,11 +89,22 @@ async def get_course_stats(): analytics = rag_system.get_course_analytics() return CourseStats( total_courses=analytics["total_courses"], - course_titles=analytics["course_titles"] + course_titles=analytics["course_titles"], ) except Exception as e: raise HTTPException(status_code=500, detail=str(e)) + +@app.post("/api/session/clear") +async def clear_session(request: ClearSessionRequest): + """Clear a conversation session""" + try: + rag_system.session_manager.clear_session(request.session_id) + return {"status": "success", "message": f"Session {request.session_id} cleared"} + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + @app.on_event("startup") async def startup_event(): """Load initial documents on startup""" @@ -92,11 +112,14 @@ async def startup_event(): if os.path.exists(docs_path): print("Loading initial documents...") try: - courses, chunks = rag_system.add_course_folder(docs_path, clear_existing=False) + courses, chunks = rag_system.add_course_folder( + docs_path, clear_existing=False + ) print(f"Loaded {courses} courses with {chunks} chunks") except 
Exception as e: print(f"Error loading documents: {e}") + # Custom static file handler with no-cache headers for development from fastapi.staticfiles import StaticFiles from fastapi.responses import FileResponse @@ -113,7 +136,7 @@ async def get_response(self, path: str, scope): response.headers["Pragma"] = "no-cache" response.headers["Expires"] = "0" return response - - + + # Serve static files for the frontend -app.mount("/", StaticFiles(directory="../frontend", html=True), name="static") \ No newline at end of file +app.mount("/", DevStaticFiles(directory="../frontend", html=True), name="static") diff --git a/backend/config.py b/backend/config.py index d9f6392ef..7379e7133 100644 --- a/backend/config.py +++ b/backend/config.py @@ -5,25 +5,26 @@ # Load environment variables from .env file load_dotenv() + @dataclass class Config: """Configuration settings for the RAG system""" + # Anthropic API settings ANTHROPIC_API_KEY: str = os.getenv("ANTHROPIC_API_KEY", "") ANTHROPIC_MODEL: str = "claude-sonnet-4-20250514" - + # Embedding model settings EMBEDDING_MODEL: str = "all-MiniLM-L6-v2" - + # Document processing settings - CHUNK_SIZE: int = 800 # Size of text chunks for vector storage - CHUNK_OVERLAP: int = 100 # Characters to overlap between chunks - MAX_RESULTS: int = 5 # Maximum search results to return - MAX_HISTORY: int = 2 # Number of conversation messages to remember - + CHUNK_SIZE: int = 800 # Size of text chunks for vector storage + CHUNK_OVERLAP: int = 100 # Characters to overlap between chunks + MAX_RESULTS: int = 5 # Maximum search results to return + MAX_HISTORY: int = 2 # Number of conversation messages to remember + # Database paths CHROMA_PATH: str = "./chroma_db" # ChromaDB storage location -config = Config() - +config = Config() diff --git a/backend/document_processor.py b/backend/document_processor.py index 266e85904..32c6648ae 100644 --- a/backend/document_processor.py +++ b/backend/document_processor.py @@ -3,81 +3,84 @@ from typing import List, 
Tuple from models import Course, Lesson, CourseChunk + class DocumentProcessor: """Processes course documents and extracts structured information""" - + def __init__(self, chunk_size: int, chunk_overlap: int): self.chunk_size = chunk_size self.chunk_overlap = chunk_overlap - + def read_file(self, file_path: str) -> str: """Read content from file with UTF-8 encoding""" try: - with open(file_path, 'r', encoding='utf-8') as file: + with open(file_path, "r", encoding="utf-8") as file: return file.read() except UnicodeDecodeError: # If UTF-8 fails, try with error handling - with open(file_path, 'r', encoding='utf-8', errors='ignore') as file: + with open(file_path, "r", encoding="utf-8", errors="ignore") as file: return file.read() - - def chunk_text(self, text: str) -> List[str]: """Split text into sentence-based chunks with overlap using config settings""" - + # Clean up the text - text = re.sub(r'\s+', ' ', text.strip()) # Normalize whitespace - + text = re.sub(r"\s+", " ", text.strip()) # Normalize whitespace + # Better sentence splitting that handles abbreviations # This regex looks for periods followed by whitespace and capital letters # but ignores common abbreviations - sentence_endings = re.compile(r'(? 
self.chunk_size and current_chunk: break - + current_chunk.append(sentence) current_size += total_addition - + # Add chunk if we have content if current_chunk: - chunks.append(' '.join(current_chunk)) - + chunks.append(" ".join(current_chunk)) + # Calculate overlap for next chunk - if hasattr(self, 'chunk_overlap') and self.chunk_overlap > 0: + if hasattr(self, "chunk_overlap") and self.chunk_overlap > 0: # Find how many sentences to overlap overlap_size = 0 overlap_sentences = 0 - + # Count backwards from end of current chunk for k in range(len(current_chunk) - 1, -1, -1): - sentence_len = len(current_chunk[k]) + (1 if k < len(current_chunk) - 1 else 0) + sentence_len = len(current_chunk[k]) + ( + 1 if k < len(current_chunk) - 1 else 0 + ) if overlap_size + sentence_len <= self.chunk_overlap: overlap_size += sentence_len overlap_sentences += 1 else: break - + # Move start position considering overlap next_start = i + len(current_chunk) - overlap_sentences i = max(next_start, i + 1) # Ensure we make progress @@ -87,14 +90,12 @@ def chunk_text(self, text: str) -> List[str]: else: # No sentences fit, move to next i += 1 - - return chunks - - + return chunks - - def process_course_document(self, file_path: str) -> Tuple[Course, List[CourseChunk]]: + def process_course_document( + self, file_path: str + ) -> Tuple[Course, List[CourseChunk]]: """ Process a course document with expected format: Line 1: Course Title: [title] @@ -104,47 +105,51 @@ def process_course_document(self, file_path: str) -> Tuple[Course, List[CourseCh """ content = self.read_file(file_path) filename = os.path.basename(file_path) - - lines = content.strip().split('\n') - + + lines = content.strip().split("\n") + # Extract course metadata from first three lines course_title = filename # Default fallback course_link = None instructor_name = "Unknown" - + # Parse course title from first line if len(lines) >= 1 and lines[0].strip(): - title_match = re.match(r'^Course Title:\s*(.+)$', lines[0].strip(), 
re.IGNORECASE) + title_match = re.match( + r"^Course Title:\s*(.+)$", lines[0].strip(), re.IGNORECASE + ) if title_match: course_title = title_match.group(1).strip() else: course_title = lines[0].strip() - + # Parse remaining lines for course metadata for i in range(1, min(len(lines), 4)): # Check first 4 lines for metadata line = lines[i].strip() if not line: continue - + # Try to match course link - link_match = re.match(r'^Course Link:\s*(.+)$', line, re.IGNORECASE) + link_match = re.match(r"^Course Link:\s*(.+)$", line, re.IGNORECASE) if link_match: course_link = link_match.group(1).strip() continue - + # Try to match instructor - instructor_match = re.match(r'^Course Instructor:\s*(.+)$', line, re.IGNORECASE) + instructor_match = re.match( + r"^Course Instructor:\s*(.+)$", line, re.IGNORECASE + ) if instructor_match: instructor_name = instructor_match.group(1).strip() continue - + # Create course object with title as ID course = Course( title=course_title, course_link=course_link, - instructor=instructor_name if instructor_name != "Unknown" else None + instructor=instructor_name if instructor_name != "Unknown" else None, ) - + # Process lessons and create chunks course_chunks = [] current_lesson = None @@ -152,108 +157,114 @@ def process_course_document(self, file_path: str) -> Tuple[Course, List[CourseCh lesson_link = None lesson_content = [] chunk_counter = 0 - + # Start processing from line 4 (after metadata) start_index = 3 if len(lines) > 3 and not lines[3].strip(): start_index = 4 # Skip empty line after instructor - + i = start_index while i < len(lines): line = lines[i] - + # Check for lesson markers (e.g., "Lesson 0: Introduction") - lesson_match = re.match(r'^Lesson\s+(\d+):\s*(.+)$', line.strip(), re.IGNORECASE) - + lesson_match = re.match( + r"^Lesson\s+(\d+):\s*(.+)$", line.strip(), re.IGNORECASE + ) + if lesson_match: # Process previous lesson if it exists if current_lesson is not None and lesson_content: - lesson_text = 
'\n'.join(lesson_content).strip() + lesson_text = "\n".join(lesson_content).strip() if lesson_text: # Add lesson to course lesson = Lesson( lesson_number=current_lesson, title=lesson_title, - lesson_link=lesson_link + lesson_link=lesson_link, ) course.lessons.append(lesson) - + # Create chunks for this lesson chunks = self.chunk_text(lesson_text) for idx, chunk in enumerate(chunks): # For the first chunk of each lesson, add lesson context if idx == 0: - chunk_with_context = f"Lesson {current_lesson} content: {chunk}" + chunk_with_context = ( + f"Lesson {current_lesson} content: {chunk}" + ) else: chunk_with_context = chunk - + course_chunk = CourseChunk( content=chunk_with_context, course_title=course.title, lesson_number=current_lesson, - chunk_index=chunk_counter + chunk_index=chunk_counter, ) course_chunks.append(course_chunk) chunk_counter += 1 - + # Start new lesson current_lesson = int(lesson_match.group(1)) lesson_title = lesson_match.group(2).strip() lesson_link = None - + # Check if next line is a lesson link if i + 1 < len(lines): next_line = lines[i + 1].strip() - link_match = re.match(r'^Lesson Link:\s*(.+)$', next_line, re.IGNORECASE) + link_match = re.match( + r"^Lesson Link:\s*(.+)$", next_line, re.IGNORECASE + ) if link_match: lesson_link = link_match.group(1).strip() i += 1 # Skip the link line so it's not added to content - + lesson_content = [] else: # Add line to current lesson content lesson_content.append(line) - + i += 1 - + # Process the last lesson if current_lesson is not None and lesson_content: - lesson_text = '\n'.join(lesson_content).strip() + lesson_text = "\n".join(lesson_content).strip() if lesson_text: lesson = Lesson( lesson_number=current_lesson, title=lesson_title, - lesson_link=lesson_link + lesson_link=lesson_link, ) course.lessons.append(lesson) - + chunks = self.chunk_text(lesson_text) for idx, chunk in enumerate(chunks): # For any chunk of each lesson, add lesson context & course title - + chunk_with_context = f"Course 
{course_title} Lesson {current_lesson} content: {chunk}" - + course_chunk = CourseChunk( content=chunk_with_context, course_title=course.title, lesson_number=current_lesson, - chunk_index=chunk_counter + chunk_index=chunk_counter, ) course_chunks.append(course_chunk) chunk_counter += 1 - + # If no lessons found, treat entire content as one document if not course_chunks and len(lines) > 2: - remaining_content = '\n'.join(lines[start_index:]).strip() + remaining_content = "\n".join(lines[start_index:]).strip() if remaining_content: chunks = self.chunk_text(remaining_content) for chunk in chunks: course_chunk = CourseChunk( content=chunk, course_title=course.title, - chunk_index=chunk_counter + chunk_index=chunk_counter, ) course_chunks.append(course_chunk) chunk_counter += 1 - + return course, course_chunks diff --git a/backend/models.py b/backend/models.py index 7f7126fa3..12ae8113e 100644 --- a/backend/models.py +++ b/backend/models.py @@ -1,22 +1,28 @@ from typing import List, Dict, Optional from pydantic import BaseModel + class Lesson(BaseModel): """Represents a lesson within a course""" + lesson_number: int # Sequential lesson number (1, 2, 3, etc.) 
- title: str # Lesson title + title: str # Lesson title lesson_link: Optional[str] = None # URL link to the lesson + class Course(BaseModel): """Represents a complete course with its lessons""" - title: str # Full course title (used as unique identifier) + + title: str # Full course title (used as unique identifier) course_link: Optional[str] = None # URL link to the course instructor: Optional[str] = None # Course instructor name (optional metadata) - lessons: List[Lesson] = [] # List of lessons in this course + lessons: List[Lesson] = [] # List of lessons in this course + class CourseChunk(BaseModel): """Represents a text chunk from a course for vector storage""" - content: str # The actual text content - course_title: str # Which course this chunk belongs to - lesson_number: Optional[int] = None # Which lesson this chunk is from - chunk_index: int # Position of this chunk in the document \ No newline at end of file + + content: str # The actual text content + course_title: str # Which course this chunk belongs to + lesson_number: Optional[int] = None # Which lesson this chunk is from + chunk_index: int # Position of this chunk in the document diff --git a/backend/rag_system.py b/backend/rag_system.py index 50d848c8e..c1322a9f2 100644 --- a/backend/rag_system.py +++ b/backend/rag_system.py @@ -1,147 +1,170 @@ from typing import List, Tuple, Optional, Dict import os +import re from document_processor import DocumentProcessor from vector_store import VectorStore from ai_generator import AIGenerator from session_manager import SessionManager -from search_tools import ToolManager, CourseSearchTool +from search_tools import ToolManager, CourseSearchTool, CourseOutlineTool from models import Course, Lesson, CourseChunk + class RAGSystem: """Main orchestrator for the Retrieval-Augmented Generation system""" - + def __init__(self, config): self.config = config - + # Initialize core components - self.document_processor = DocumentProcessor(config.CHUNK_SIZE, 
config.CHUNK_OVERLAP) - self.vector_store = VectorStore(config.CHROMA_PATH, config.EMBEDDING_MODEL, config.MAX_RESULTS) - self.ai_generator = AIGenerator(config.ANTHROPIC_API_KEY, config.ANTHROPIC_MODEL) + self.document_processor = DocumentProcessor( + config.CHUNK_SIZE, config.CHUNK_OVERLAP + ) + self.vector_store = VectorStore( + config.CHROMA_PATH, config.EMBEDDING_MODEL, config.MAX_RESULTS + ) + self.ai_generator = AIGenerator( + config.ANTHROPIC_API_KEY, config.ANTHROPIC_MODEL + ) self.session_manager = SessionManager(config.MAX_HISTORY) - + # Initialize search tools self.tool_manager = ToolManager() self.search_tool = CourseSearchTool(self.vector_store) + self.outline_tool = CourseOutlineTool(self.vector_store) self.tool_manager.register_tool(self.search_tool) - + self.tool_manager.register_tool(self.outline_tool) + def add_course_document(self, file_path: str) -> Tuple[Course, int]: """ Add a single course document to the knowledge base. - + Args: file_path: Path to the course document - + Returns: Tuple of (Course object, number of chunks created) """ try: # Process the document - course, course_chunks = self.document_processor.process_course_document(file_path) - + course, course_chunks = self.document_processor.process_course_document( + file_path + ) + # Add course metadata to vector store for semantic search self.vector_store.add_course_metadata(course) - + # Add course content chunks to vector store self.vector_store.add_course_content(course_chunks) - + return course, len(course_chunks) except Exception as e: print(f"Error processing course document {file_path}: {e}") return None, 0 - - def add_course_folder(self, folder_path: str, clear_existing: bool = False) -> Tuple[int, int]: + + def add_course_folder( + self, folder_path: str, clear_existing: bool = False + ) -> Tuple[int, int]: """ Add all course documents from a folder. 
- + Args: folder_path: Path to folder containing course documents clear_existing: Whether to clear existing data first - + Returns: Tuple of (total courses added, total chunks created) """ total_courses = 0 total_chunks = 0 - + # Clear existing data if requested if clear_existing: print("Clearing existing data for fresh rebuild...") self.vector_store.clear_all_data() - + if not os.path.exists(folder_path): print(f"Folder {folder_path} does not exist") return 0, 0 - + # Get existing course titles to avoid re-processing existing_course_titles = set(self.vector_store.get_existing_course_titles()) - + # Process each file in the folder for file_name in os.listdir(folder_path): file_path = os.path.join(folder_path, file_name) - if os.path.isfile(file_path) and file_name.lower().endswith(('.pdf', '.docx', '.txt')): + if os.path.isfile(file_path) and file_name.lower().endswith( + (".pdf", ".docx", ".txt") + ): try: # Check if this course might already exist # We'll process the document to get the course ID, but only add if new - course, course_chunks = self.document_processor.process_course_document(file_path) - + course, course_chunks = ( + self.document_processor.process_course_document(file_path) + ) + if course and course.title not in existing_course_titles: # This is a new course - add it to the vector store self.vector_store.add_course_metadata(course) self.vector_store.add_course_content(course_chunks) total_courses += 1 total_chunks += len(course_chunks) - print(f"Added new course: {course.title} ({len(course_chunks)} chunks)") + print( + f"Added new course: {course.title} ({len(course_chunks)} chunks)" + ) existing_course_titles.add(course.title) elif course: print(f"Course already exists: {course.title} - skipping") except Exception as e: print(f"Error processing {file_name}: {e}") - + return total_courses, total_chunks - - def query(self, query: str, session_id: Optional[str] = None) -> Tuple[str, List[str]]: + + def query( + self, query: str, session_id: 
Optional[str] = None + ) -> Tuple[str, List[str]]: """ Process a user query using the RAG system with tool-based search. - + Args: query: User's question session_id: Optional session ID for conversation context - + Returns: - Tuple of (response, sources list - empty for tool-based approach) + Tuple of (response, sources list, source_links list) """ # Create prompt for the AI with clear instructions prompt = f"""Answer this question about course materials: {query}""" - + # Get conversation history if session exists history = None if session_id: history = self.session_manager.get_conversation_history(session_id) - + # Generate response using AI with tools response = self.ai_generator.generate_response( query=prompt, conversation_history=history, tools=self.tool_manager.get_tool_definitions(), - tool_manager=self.tool_manager + tool_manager=self.tool_manager, ) - - # Get sources from the search tool + + # Get sources and source links from the search tool sources = self.tool_manager.get_last_sources() + source_links = self.tool_manager.get_last_source_links() # Reset sources after retrieving them self.tool_manager.reset_sources() - + # Update conversation history if session_id: self.session_manager.add_exchange(session_id, query, response) - - # Return response with sources from tool searches - return response, sources - + + # Return response with sources and links from tool searches + return response, sources, source_links + def get_course_analytics(self) -> Dict: """Get analytics about the course catalog""" return { "total_courses": self.vector_store.get_course_count(), - "course_titles": self.vector_store.get_existing_course_titles() - } \ No newline at end of file + "course_titles": self.vector_store.get_existing_course_titles(), + } diff --git a/backend/search_tools.py b/backend/search_tools.py index adfe82352..d1a606eae 100644 --- a/backend/search_tools.py +++ b/backend/search_tools.py @@ -5,12 +5,12 @@ class Tool(ABC): """Abstract base class for all tools""" - + 
@abstractmethod def get_tool_definition(self) -> Dict[str, Any]: """Return Anthropic tool definition for this tool""" pass - + @abstractmethod def execute(self, **kwargs) -> str: """Execute the tool with given parameters""" @@ -19,11 +19,11 @@ def execute(self, **kwargs) -> str: class CourseSearchTool(Tool): """Tool for searching course content with semantic course name matching""" - + def __init__(self, vector_store: VectorStore): self.store = vector_store self.last_sources = [] # Track sources from last search - + def get_tool_definition(self) -> Dict[str, Any]: """Return Anthropic tool definition for this tool""" return { @@ -33,46 +33,49 @@ def get_tool_definition(self) -> Dict[str, Any]: "type": "object", "properties": { "query": { - "type": "string", - "description": "What to search for in the course content" + "type": "string", + "description": "What to search for in the course content", }, "course_name": { "type": "string", - "description": "Course title (partial matches work, e.g. 'MCP', 'Introduction')" + "description": "Course title (partial matches work, e.g. 'MCP', 'Introduction')", }, "lesson_number": { "type": "integer", - "description": "Specific lesson number to search within (e.g. 1, 2, 3)" - } + "description": "Specific lesson number to search within (e.g. 1, 2, 3)", + }, }, - "required": ["query"] - } + "required": ["query"], + }, } - - def execute(self, query: str, course_name: Optional[str] = None, lesson_number: Optional[int] = None) -> str: + + def execute( + self, + query: str, + course_name: Optional[str] = None, + lesson_number: Optional[int] = None, + ) -> str: """ Execute the search tool with given parameters. 
- + Args: query: What to search for course_name: Optional course filter lesson_number: Optional lesson filter - + Returns: Formatted search results or error message """ - + # Use the vector store's unified search interface results = self.store.search( - query=query, - course_name=course_name, - lesson_number=lesson_number + query=query, course_name=course_name, lesson_number=lesson_number ) - + # Handle errors if results.error: return results.error - + # Handle empty results if results.is_empty(): filter_info = "" @@ -81,44 +84,128 @@ def execute(self, query: str, course_name: Optional[str] = None, lesson_number: if lesson_number: filter_info += f" in lesson {lesson_number}" return f"No relevant content found{filter_info}." - + # Format and return results return self._format_results(results) - + def _format_results(self, results: SearchResults) -> str: """Format search results with course and lesson context""" formatted = [] sources = [] # Track sources for the UI - + source_links = [] # Track lesson links for the UI + for doc, meta in zip(results.documents, results.metadata): - course_title = meta.get('course_title', 'unknown') - lesson_num = meta.get('lesson_number') - + course_title = meta.get("course_title", "unknown") + lesson_num = meta.get("lesson_number") + # Build context header header = f"[{course_title}" if lesson_num is not None: header += f" - Lesson {lesson_num}" header += "]" - + # Track source for the UI source = course_title if lesson_num is not None: source += f" - Lesson {lesson_num}" sources.append(source) - + + # Get lesson link if available + lesson_link = None + if lesson_num is not None: + lesson_link = self.store.get_lesson_link(course_title, lesson_num) + source_links.append(lesson_link) + formatted.append(f"{header}\n{doc}") - + # Store sources for retrieval self.last_sources = sources - + self.last_source_links = source_links + return "\n\n".join(formatted) + +class CourseOutlineTool(Tool): + """Tool for retrieving complete course 
outlines and lesson lists""" + + def __init__(self, vector_store: VectorStore): + self.store = vector_store + + def get_tool_definition(self) -> Dict[str, Any]: + """Return Anthropic tool definition for this tool""" + return { + "name": "get_course_outline", + "description": "Get the complete outline of a course including title, link, and all lessons with their numbers and titles", + "input_schema": { + "type": "object", + "properties": { + "course_name": { + "type": "string", + "description": "Course title (partial matches work, e.g. 'MCP', 'Introduction')", + } + }, + "required": ["course_name"], + }, + } + + def execute(self, course_name: str) -> str: + """ + Execute the outline retrieval tool. + + Args: + course_name: Course name to get outline for + + Returns: + Formatted course outline or error message + """ + # Resolve course name using vector store's existing method + course_title = self.store._resolve_course_name(course_name) + if not course_title: + return f"No course found matching '{course_name}'" + + # Get course metadata + try: + results = self.store.course_catalog.get(ids=[course_title]) + if not results or not results["metadatas"]: + return f"Course metadata not found for '{course_title}'" + + metadata = results["metadatas"][0] + + # Parse lessons from JSON + import json + + lessons_json = metadata.get("lessons_json") + if not lessons_json: + return f"No lesson information available for '{course_title}'" + + lessons = json.loads(lessons_json) + + # Format the outline with proper markdown + header = ( + f"**Course Title:** {metadata.get('title', course_title)}\n\n" + f"**Course Link:** {metadata.get('course_link', 'N/A')}\n\n" + f"**Total Lessons:** {len(lessons)}\n\n" + f"**Lesson Outline:**\n" + ) + + lesson_lines = [] + for lesson in lessons: + lesson_num = lesson.get("lesson_number", "N/A") + lesson_title = lesson.get("lesson_title", "N/A") + lesson_lines.append(f"- **Lesson {lesson_num}:** {lesson_title}") + + return header + 
"\n".join(lesson_lines) + + except Exception as e: + return f"Error retrieving course outline: {str(e)}" + + class ToolManager: """Manages available tools for the AI""" - + def __init__(self): self.tools = {} - + def register_tool(self, tool: Tool): """Register any tool that implements the Tool interface""" tool_def = tool.get_tool_definition() @@ -127,28 +214,37 @@ def register_tool(self, tool: Tool): raise ValueError("Tool must have a 'name' in its definition") self.tools[tool_name] = tool - def get_tool_definitions(self) -> list: """Get all tool definitions for Anthropic tool calling""" return [tool.get_tool_definition() for tool in self.tools.values()] - + def execute_tool(self, tool_name: str, **kwargs) -> str: """Execute a tool by name with given parameters""" if tool_name not in self.tools: return f"Tool '{tool_name}' not found" - + return self.tools[tool_name].execute(**kwargs) - + def get_last_sources(self) -> list: """Get sources from the last search operation""" # Check all tools for last_sources attribute for tool in self.tools.values(): - if hasattr(tool, 'last_sources') and tool.last_sources: + if hasattr(tool, "last_sources") and tool.last_sources: return tool.last_sources return [] + def get_last_source_links(self) -> list: + """Get source links from the last search operation""" + # Check all tools for last_source_links attribute + for tool in self.tools.values(): + if hasattr(tool, "last_source_links") and tool.last_source_links: + return tool.last_source_links + return [] + def reset_sources(self): """Reset sources from all tools that track sources""" for tool in self.tools.values(): - if hasattr(tool, 'last_sources'): - tool.last_sources = [] \ No newline at end of file + if hasattr(tool, "last_sources"): + tool.last_sources = [] + if hasattr(tool, "last_source_links"): + tool.last_source_links = [] diff --git a/backend/session_manager.py b/backend/session_manager.py index a5a96b1a1..9e17f346b 100644 --- a/backend/session_manager.py +++ 
b/backend/session_manager.py @@ -1,61 +1,66 @@ from typing import Dict, List, Optional from dataclasses import dataclass + @dataclass class Message: """Represents a single message in a conversation""" - role: str # "user" or "assistant" + + role: str # "user" or "assistant" content: str # The message content + class SessionManager: """Manages conversation sessions and message history""" - + def __init__(self, max_history: int = 5): self.max_history = max_history self.sessions: Dict[str, List[Message]] = {} self.session_counter = 0 - + def create_session(self) -> str: """Create a new conversation session""" self.session_counter += 1 session_id = f"session_{self.session_counter}" self.sessions[session_id] = [] return session_id - + def add_message(self, session_id: str, role: str, content: str): """Add a message to the conversation history""" if session_id not in self.sessions: self.sessions[session_id] = [] - + message = Message(role=role, content=content) self.sessions[session_id].append(message) - + # Keep conversation history within limits if len(self.sessions[session_id]) > self.max_history * 2: - self.sessions[session_id] = self.sessions[session_id][-self.max_history * 2:] - + self.sessions[session_id] = self.sessions[session_id][ + -self.max_history * 2 : + ] + def add_exchange(self, session_id: str, user_message: str, assistant_message: str): """Add a complete question-answer exchange""" self.add_message(session_id, "user", user_message) self.add_message(session_id, "assistant", assistant_message) - + def get_conversation_history(self, session_id: Optional[str]) -> Optional[str]: """Get formatted conversation history for a session""" if not session_id or session_id not in self.sessions: return None - + messages = self.sessions[session_id] if not messages: return None - + # Format messages for context formatted_messages = [] for msg in messages: formatted_messages.append(f"{msg.role.title()}: {msg.content}") - + return "\n".join(formatted_messages) - + def 
clear_session(self, session_id: str): """Clear all messages from a session""" if session_id in self.sessions: - self.sessions[session_id] = [] \ No newline at end of file + self.sessions[session_id] = [] diff --git a/backend/tests/__init__.py b/backend/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py new file mode 100644 index 000000000..1cdf8ab3c --- /dev/null +++ b/backend/tests/conftest.py @@ -0,0 +1,43 @@ +import sys +import os + +# Add backend and tests directories to path so imports work +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) +sys.path.insert(0, os.path.dirname(__file__)) + +import pytest +from unittest.mock import MagicMock, patch +from helpers import MockConfig, make_valid_search_results + + +@pytest.fixture +def mock_config(): + """Shared MockConfig instance.""" + return MockConfig() + + +@pytest.fixture +def mock_rag_system(): + """A MagicMock standing in for RAGSystem with pre-wired sub-components.""" + rag = MagicMock() + rag.session_manager.create_session.return_value = "test-session-123" + rag.query.return_value = ( + "This is a test answer.", + ["Source A", "Source B"], + ["http://example.com/a", "http://example.com/b"], + ) + rag.get_course_analytics.return_value = { + "total_courses": 2, + "course_titles": ["Course A", "Course B"], + } + return rag + + +@pytest.fixture +def mock_vector_store(): + """A MagicMock standing in for VectorStore.""" + store = MagicMock() + store.get_course_count.return_value = 2 + store.get_existing_course_titles.return_value = ["Course A", "Course B"] + store.search.return_value = make_valid_search_results(2) + return store diff --git a/backend/tests/helpers.py b/backend/tests/helpers.py new file mode 100644 index 000000000..9e3e9dcfa --- /dev/null +++ b/backend/tests/helpers.py @@ -0,0 +1,67 @@ +"""Shared test helpers and factories.""" + +import sys +import os +from dataclasses import dataclass +from 
unittest.mock import MagicMock + +# Add backend to path so imports work +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) + +from vector_store import SearchResults + + +@dataclass +class MockConfig: + ANTHROPIC_API_KEY: str = "test-key" + ANTHROPIC_MODEL: str = "claude-sonnet-4-20250514" + EMBEDDING_MODEL: str = "all-MiniLM-L6-v2" + CHUNK_SIZE: int = 800 + CHUNK_OVERLAP: int = 100 + MAX_RESULTS: int = 5 + MAX_HISTORY: int = 2 + CHROMA_PATH: str = "./test_chroma_db" + + +def make_search_results(documents=None, metadata=None, distances=None, error=None): + """Factory for SearchResults objects.""" + if error: + return SearchResults.empty(error) + return SearchResults( + documents=documents or [], + metadata=metadata or [], + distances=distances or [], + error=None, + ) + + +def make_valid_search_results(n=2): + """Create valid search results with n items.""" + docs = [f"Content about topic {i}" for i in range(n)] + meta = [ + {"course_title": f"Course {i}", "lesson_number": i + 1, "chunk_index": i} + for i in range(n) + ] + dists = [0.1 * (i + 1) for i in range(n)] + return SearchResults(documents=docs, metadata=meta, distances=dists) + + +def make_anthropic_response(content_blocks, stop_reason="end_turn"): + """Factory for mock Anthropic API responses.""" + mock_response = MagicMock() + mock_response.stop_reason = stop_reason + + blocks = [] + for block in content_blocks: + mock_block = MagicMock() + mock_block.type = block["type"] + if block["type"] == "text": + mock_block.text = block["text"] + elif block["type"] == "tool_use": + mock_block.id = block["id"] + mock_block.name = block["name"] + mock_block.input = block["input"] + blocks.append(mock_block) + + mock_response.content = blocks + return mock_response diff --git a/backend/tests/test_ai_generator.py b/backend/tests/test_ai_generator.py new file mode 100644 index 000000000..1769c71ab --- /dev/null +++ b/backend/tests/test_ai_generator.py @@ -0,0 +1,421 @@ +"""Tests for AIGenerator tool 
calling and response handling.""" + +import pytest +import anthropic +from unittest.mock import MagicMock, patch, call +from helpers import make_anthropic_response +from ai_generator import AIGenerator + + +@pytest.fixture +def generator(): + with patch("ai_generator.anthropic.Anthropic"): + gen = AIGenerator(api_key="test-key", model="claude-sonnet-4-20250514") + return gen + + +@pytest.fixture +def tool_manager(): + tm = MagicMock() + tm.execute_tool.return_value = "Tool result: content about topic" + return tm + + +@pytest.fixture +def sample_tools(): + return [ + { + "name": "search_course_content", + "description": "Search course materials", + "input_schema": { + "type": "object", + "properties": {"query": {"type": "string"}}, + "required": ["query"], + }, + } + ] + + +class TestNoToolUsage: + def test_direct_text_response(self, generator, sample_tools): + """When Claude returns text (no tools), returns text directly.""" + response = make_anthropic_response( + [{"type": "text", "text": "Hello, I can help!"}], + stop_reason="end_turn", + ) + generator.client.messages.create.return_value = response + + result = generator.generate_response(query="hi", tools=sample_tools) + + assert result == "Hello, I can help!" 
+ + def test_empty_content_returns_fallback(self, generator, sample_tools): + """When response.content is empty, returns a fallback message instead of crashing.""" + response = make_anthropic_response([], stop_reason="end_turn") + response.content = [] # explicitly empty + generator.client.messages.create.return_value = response + + result = generator.generate_response(query="test", tools=sample_tools) + + assert "able to generate a response" in result.lower() + + +class TestSingleToolRound: + def test_tool_use_calls_tool_manager(self, generator, tool_manager, sample_tools): + """When Claude returns tool_use, calls tool_manager.execute_tool().""" + tool_response = make_anthropic_response( + [ + { + "type": "tool_use", + "id": "t1", + "name": "search_course_content", + "input": {"query": "neural networks"}, + } + ], + stop_reason="tool_use", + ) + text_response = make_anthropic_response( + [{"type": "text", "text": "Neural networks are..."}], + stop_reason="end_turn", + ) + generator.client.messages.create.side_effect = [ + tool_response, + text_response, + ] + + generator.generate_response( + query="what are neural networks", + tools=sample_tools, + tool_manager=tool_manager, + ) + + tool_manager.execute_tool.assert_called_once_with( + "search_course_content", query="neural networks" + ) + + def test_tool_use_then_synthesis(self, generator, tool_manager, sample_tools): + """Round 1: tool_use -> execute -> Round 2: Claude synthesizes answer.""" + tool_response = make_anthropic_response( + [ + { + "type": "tool_use", + "id": "t1", + "name": "search_course_content", + "input": {"query": "transformers"}, + } + ], + stop_reason="tool_use", + ) + synthesis_response = make_anthropic_response( + [{"type": "text", "text": "Transformers use attention mechanisms."}], + stop_reason="end_turn", + ) + generator.client.messages.create.side_effect = [ + tool_response, + synthesis_response, + ] + + result = generator.generate_response( + query="explain transformers", + 
tools=sample_tools, + tool_manager=tool_manager, + ) + + assert result == "Transformers use attention mechanisms." + assert generator.client.messages.create.call_count == 2 + + def test_course_outline_returns_directly( + self, generator, tool_manager, sample_tools + ): + """get_course_outline tool result is returned directly without synthesis.""" + outline_result = "**Course Title:** MCP\n- Lesson 1: Intro" + tool_manager.execute_tool.return_value = outline_result + + tool_response = make_anthropic_response( + [ + { + "type": "tool_use", + "id": "t1", + "name": "get_course_outline", + "input": {"course_name": "MCP"}, + } + ], + stop_reason="tool_use", + ) + generator.client.messages.create.return_value = tool_response + + result = generator.generate_response( + query="outline of MCP", + tools=sample_tools, + tool_manager=tool_manager, + ) + + assert result == outline_result + # Should NOT make a second API call for synthesis + assert generator.client.messages.create.call_count == 1 + + +class TestMultiRoundToolCalling: + def test_two_rounds_of_tool_calls(self, generator, tool_manager, sample_tools): + """Loop executes up to 2 tool rounds before final synthesis call.""" + tool_response_1 = make_anthropic_response( + [ + { + "type": "tool_use", + "id": "t1", + "name": "search_course_content", + "input": {"query": "round 1"}, + } + ], + stop_reason="tool_use", + ) + tool_response_2 = make_anthropic_response( + [ + { + "type": "tool_use", + "id": "t2", + "name": "search_course_content", + "input": {"query": "round 2"}, + } + ], + stop_reason="tool_use", + ) + final_response = make_anthropic_response( + [{"type": "text", "text": "Final answer after 2 rounds."}], + stop_reason="end_turn", + ) + generator.client.messages.create.side_effect = [ + tool_response_1, + tool_response_2, + final_response, + ] + + result = generator.generate_response( + query="complex question", + tools=sample_tools, + tool_manager=tool_manager, + ) + + assert result == "Final answer after 2 
rounds." + # 2 tool rounds + 1 final synthesis = 3 API calls + assert generator.client.messages.create.call_count == 3 + assert tool_manager.execute_tool.call_count == 2 + + def test_messages_accumulate_across_rounds( + self, generator, tool_manager, sample_tools + ): + """2nd API call's messages kwarg contains tool results from round 1.""" + tool_manager.execute_tool.return_value = "Result from round 1" + + tool_response_1 = make_anthropic_response( + [ + { + "type": "tool_use", + "id": "t1", + "name": "search_course_content", + "input": {"query": "round 1"}, + } + ], + stop_reason="tool_use", + ) + text_response = make_anthropic_response( + [{"type": "text", "text": "Synthesized answer."}], + stop_reason="end_turn", + ) + generator.client.messages.create.side_effect = [ + tool_response_1, + text_response, + ] + + generator.generate_response( + query="test query", + tools=sample_tools, + tool_manager=tool_manager, + ) + + # The 2nd API call should have accumulated messages + second_call_kwargs = generator.client.messages.create.call_args_list[1][1] + msgs = second_call_kwargs["messages"] + + # Should be: user query, assistant tool_use, user tool_result + assert len(msgs) == 3 + assert msgs[0] == {"role": "user", "content": "test query"} + assert msgs[1]["role"] == "assistant" + assert msgs[2]["role"] == "user" + # The tool_result content should contain our result + tool_result_content = msgs[2]["content"] + assert len(tool_result_content) == 1 + assert tool_result_content[0]["type"] == "tool_result" + assert tool_result_content[0]["content"] == "Result from round 1" + + def test_parallel_tool_calls_all_executed( + self, generator, tool_manager, sample_tools + ): + """When Claude calls 2 tools in one response, both execute.""" + tool_manager.execute_tool.side_effect = ["Result A", "Result B"] + + parallel_response = make_anthropic_response( + [ + { + "type": "tool_use", + "id": "t1", + "name": "search_course_content", + "input": {"query": "topic A"}, + }, + { + 
"type": "tool_use", + "id": "t2", + "name": "search_course_content", + "input": {"query": "topic B"}, + }, + ], + stop_reason="tool_use", + ) + synthesis_response = make_anthropic_response( + [{"type": "text", "text": "Combined answer."}], + stop_reason="end_turn", + ) + generator.client.messages.create.side_effect = [ + parallel_response, + synthesis_response, + ] + + result = generator.generate_response( + query="compare topics", + tools=sample_tools, + tool_manager=tool_manager, + ) + + assert result == "Combined answer." + assert tool_manager.execute_tool.call_count == 2 + tool_manager.execute_tool.assert_any_call( + "search_course_content", query="topic A" + ) + tool_manager.execute_tool.assert_any_call( + "search_course_content", query="topic B" + ) + + # Verify both results sent back to API + second_call_msgs = generator.client.messages.create.call_args_list[1][1][ + "messages" + ] + tool_results_msg = second_call_msgs[-1]["content"] + assert len(tool_results_msg) == 2 + assert tool_results_msg[0]["content"] == "Result A" + assert tool_results_msg[1]["content"] == "Result B" + + def test_parallel_tools_one_fails_still_returns_all_results( + self, generator, tool_manager, sample_tools + ): + """One tool fails -> error sent for that tool, successful result still included.""" + tool_manager.execute_tool.side_effect = [ + "Success result", + RuntimeError("Tool B crashed"), + ] + + parallel_response = make_anthropic_response( + [ + { + "type": "tool_use", + "id": "t1", + "name": "search_course_content", + "input": {"query": "good query"}, + }, + { + "type": "tool_use", + "id": "t2", + "name": "search_course_content", + "input": {"query": "bad query"}, + }, + ], + stop_reason="tool_use", + ) + final_response = make_anthropic_response( + [{"type": "text", "text": "Partial answer with error context."}], + stop_reason="end_turn", + ) + generator.client.messages.create.side_effect = [ + parallel_response, + final_response, + ] + + result = generator.generate_response( 
+ query="multi tool query", + tools=sample_tools, + tool_manager=tool_manager, + ) + + assert isinstance(result, str) + # Both tools were attempted + assert tool_manager.execute_tool.call_count == 2 + + # Verify both results (success + error) sent back to API + second_call_msgs = generator.client.messages.create.call_args_list[1][1][ + "messages" + ] + tool_results_msg = second_call_msgs[-1]["content"] + assert len(tool_results_msg) == 2 + # First tool succeeded + assert tool_results_msg[0]["content"] == "Success result" + assert "is_error" not in tool_results_msg[0] + # Second tool has error + assert "Error" in tool_results_msg[1]["content"] + assert tool_results_msg[1]["is_error"] is True + + +class TestErrorHandling: + def test_tool_execution_exception_handled( + self, generator, tool_manager, sample_tools + ): + """When tool_manager raises, error is caught and loop breaks.""" + tool_manager.execute_tool.side_effect = RuntimeError("Tool crashed") + + tool_response = make_anthropic_response( + [ + { + "type": "tool_use", + "id": "t1", + "name": "search_course_content", + "input": {"query": "test"}, + } + ], + stop_reason="tool_use", + ) + # After exception, a final synthesis call is made (no tools) + final_response = make_anthropic_response( + [{"type": "text", "text": "I encountered an error."}], + stop_reason="end_turn", + ) + generator.client.messages.create.side_effect = [ + tool_response, + final_response, + ] + + result = generator.generate_response( + query="test", tools=sample_tools, tool_manager=tool_manager + ) + + # The function should still return (error is handled in _handle_tool_execution) + assert isinstance(result, str) + + def test_api_exception_wrapped_as_runtime_error(self, generator, sample_tools): + """When client.messages.create() raises APIError, it's wrapped as RuntimeError with context.""" + generator.client.messages.create.side_effect = anthropic.APIError( + message="rate limit exceeded", + request=MagicMock(), + body=None, + ) + + with 
pytest.raises(RuntimeError, match="Anthropic API error"): + generator.generate_response(query="test", tools=sample_tools) + + def test_auth_exception_wrapped_with_context(self, generator, sample_tools): + """When client.messages.create() raises AuthenticationError, it's wrapped with auth context.""" + generator.client.messages.create.side_effect = anthropic.AuthenticationError( + message="invalid api key", + response=MagicMock(status_code=401, headers={}), + body=None, + ) + + with pytest.raises(RuntimeError, match="authentication failed"): + generator.generate_response(query="test", tools=sample_tools) diff --git a/backend/tests/test_api_endpoints.py b/backend/tests/test_api_endpoints.py new file mode 100644 index 000000000..f7259a47b --- /dev/null +++ b/backend/tests/test_api_endpoints.py @@ -0,0 +1,172 @@ +"""Tests for FastAPI API endpoints. + +Because the production app.py mounts static files from ../frontend (which +doesn't exist in the test environment), we define a lightweight test app that +mirrors the endpoint logic and wires in a mock RAGSystem. 
+""" + +import pytest +from unittest.mock import MagicMock +from fastapi import FastAPI, HTTPException +from fastapi.testclient import TestClient +from pydantic import BaseModel +from typing import List, Optional, Dict, Union + + +# --------------------------------------------------------------------------- +# Pydantic models (duplicated from app.py to avoid import side-effects) +# --------------------------------------------------------------------------- + +class QueryRequest(BaseModel): + query: str + session_id: Optional[str] = None + + +class QueryResponse(BaseModel): + answer: str + sources: List[Union[str, Dict[str, str]]] + session_id: str + + +class CourseStats(BaseModel): + total_courses: int + course_titles: List[str] + + +class ClearSessionRequest(BaseModel): + session_id: str + + +# --------------------------------------------------------------------------- +# Test app factory +# --------------------------------------------------------------------------- + +def _create_test_app(rag_system: MagicMock) -> FastAPI: + """Build a minimal FastAPI app with the same endpoints as production.""" + test_app = FastAPI() + + @test_app.post("/api/query", response_model=QueryResponse) + async def query_documents(request: QueryRequest): + try: + session_id = request.session_id + if not session_id: + session_id = rag_system.session_manager.create_session() + answer, sources, _links = rag_system.query(request.query, session_id) + return QueryResponse(answer=answer, sources=sources, session_id=session_id) + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + @test_app.get("/api/courses", response_model=CourseStats) + async def get_course_stats(): + try: + analytics = rag_system.get_course_analytics() + return CourseStats( + total_courses=analytics["total_courses"], + course_titles=analytics["course_titles"], + ) + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + @test_app.post("/api/session/clear") + async def 
clear_session(request: ClearSessionRequest): + try: + rag_system.session_manager.clear_session(request.session_id) + return {"status": "success", "message": f"Session {request.session_id} cleared"} + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + return test_app + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + +@pytest.fixture +def client(mock_rag_system): + """TestClient wired to the mock RAG system.""" + app = _create_test_app(mock_rag_system) + return TestClient(app) + + +# --------------------------------------------------------------------------- +# /api/query +# --------------------------------------------------------------------------- + +class TestQueryEndpoint: + def test_query_with_session_id(self, client, mock_rag_system): + resp = client.post("/api/query", json={"query": "What is RAG?", "session_id": "s1"}) + assert resp.status_code == 200 + data = resp.json() + assert data["answer"] == "This is a test answer." 
+ assert data["sources"] == ["Source A", "Source B"] + assert data["session_id"] == "s1" + mock_rag_system.query.assert_called_once_with("What is RAG?", "s1") + + def test_query_creates_session_when_missing(self, client, mock_rag_system): + resp = client.post("/api/query", json={"query": "Hello"}) + assert resp.status_code == 200 + data = resp.json() + assert data["session_id"] == "test-session-123" + mock_rag_system.session_manager.create_session.assert_called_once() + + def test_query_returns_dict_sources(self, client, mock_rag_system): + mock_rag_system.query.return_value = ( + "Answer", + [{"title": "Lesson 1", "link": "http://example.com"}], + [], + ) + resp = client.post("/api/query", json={"query": "test"}) + assert resp.status_code == 200 + assert resp.json()["sources"] == [{"title": "Lesson 1", "link": "http://example.com"}] + + def test_query_missing_body_returns_422(self, client): + resp = client.post("/api/query", json={}) + assert resp.status_code == 422 + + def test_query_rag_error_returns_500(self, client, mock_rag_system): + mock_rag_system.query.side_effect = RuntimeError("boom") + resp = client.post("/api/query", json={"query": "fail", "session_id": "s1"}) + assert resp.status_code == 500 + assert "boom" in resp.json()["detail"] + + +# --------------------------------------------------------------------------- +# /api/courses +# --------------------------------------------------------------------------- + +class TestCoursesEndpoint: + def test_get_courses(self, client): + resp = client.get("/api/courses") + assert resp.status_code == 200 + data = resp.json() + assert data["total_courses"] == 2 + assert data["course_titles"] == ["Course A", "Course B"] + + def test_courses_error_returns_500(self, client, mock_rag_system): + mock_rag_system.get_course_analytics.side_effect = RuntimeError("db down") + resp = client.get("/api/courses") + assert resp.status_code == 500 + assert "db down" in resp.json()["detail"] + + +# 
--------------------------------------------------------------------------- +# /api/session/clear +# --------------------------------------------------------------------------- + +class TestClearSessionEndpoint: + def test_clear_session_success(self, client, mock_rag_system): + resp = client.post("/api/session/clear", json={"session_id": "s1"}) + assert resp.status_code == 200 + data = resp.json() + assert data["status"] == "success" + mock_rag_system.session_manager.clear_session.assert_called_once_with("s1") + + def test_clear_session_missing_id_returns_422(self, client): + resp = client.post("/api/session/clear", json={}) + assert resp.status_code == 422 + + def test_clear_session_error_returns_500(self, client, mock_rag_system): + mock_rag_system.session_manager.clear_session.side_effect = KeyError("no session") + resp = client.post("/api/session/clear", json={"session_id": "bad"}) + assert resp.status_code == 500 diff --git a/backend/tests/test_rag_integration.py b/backend/tests/test_rag_integration.py new file mode 100644 index 000000000..41ad8463d --- /dev/null +++ b/backend/tests/test_rag_integration.py @@ -0,0 +1,108 @@ +"""Tests for RAG system query pipeline with mocked dependencies.""" + +import pytest +from unittest.mock import MagicMock, patch + + +class TestRAGQueryPipeline: + """Test the full query pipeline with mocked external dependencies.""" + + @pytest.fixture + def mock_deps(self): + """Set up mocked RAG system with all dependencies mocked.""" + with ( + patch("rag_system.DocumentProcessor"), + patch("rag_system.VectorStore"), + patch("rag_system.AIGenerator"), + patch("rag_system.SessionManager"), + patch("rag_system.CourseSearchTool"), + patch("rag_system.CourseOutlineTool"), + patch("rag_system.ToolManager") as mock_tm_cls, + ): + from rag_system import RAGSystem + from helpers import MockConfig + + config = MockConfig() + rag = RAGSystem(config) + + # rag.tool_manager is now a MagicMock instance + 
rag.ai_generator.generate_response.return_value = "This is the answer." + rag.tool_manager.get_last_sources.return_value = ["Course A - Lesson 1"] + rag.tool_manager.get_last_source_links.return_value = [ + "https://example.com/1" + ] + rag.session_manager.get_conversation_history.return_value = None + + yield rag + + def test_query_returns_response_and_sources(self, mock_deps): + """Happy path: returns (answer, sources, source_links) tuple.""" + rag = mock_deps + + response, sources, source_links = rag.query("What is MCP?") + + assert response == "This is the answer." + assert sources == ["Course A - Lesson 1"] + assert source_links == ["https://example.com/1"] + + def test_query_passes_tools_to_generator(self, mock_deps): + """get_tool_definitions() is passed to ai_generator.generate_response().""" + rag = mock_deps + rag.tool_manager.get_tool_definitions.return_value = [ + {"name": "search_course_content"} + ] + + rag.query("test question") + + call_kwargs = rag.ai_generator.generate_response.call_args + assert call_kwargs.kwargs["tools"] == [{"name": "search_course_content"}] + + def test_query_passes_tool_manager(self, mock_deps): + """tool_manager instance is passed to generator for tool dispatch.""" + rag = mock_deps + + rag.query("test question") + + call_kwargs = rag.ai_generator.generate_response.call_args + assert call_kwargs.kwargs["tool_manager"] is rag.tool_manager + + def test_query_collects_sources_after_response(self, mock_deps): + """Sources retrieved via get_last_sources() after generation.""" + rag = mock_deps + + rag.query("test") + + gen_call_order = rag.ai_generator.generate_response.call_args_list + src_call_order = rag.tool_manager.get_last_sources.call_args_list + assert len(gen_call_order) == 1 + assert len(src_call_order) == 1 + + def test_query_resets_sources(self, mock_deps): + """reset_sources() called after source collection.""" + rag = mock_deps + + rag.query("test") + + rag.tool_manager.reset_sources.assert_called_once() + + def 
test_query_exception_propagates_to_caller(self, mock_deps): + """When generator raises, exception propagates (no try/except in query()).""" + rag = mock_deps + rag.ai_generator.generate_response.side_effect = Exception("API auth failed") + + with pytest.raises(Exception, match="API auth failed"): + rag.query("test question") + + def test_query_with_session_passes_history(self, mock_deps): + """Session history is passed as conversation_history parameter.""" + rag = mock_deps + rag.session_manager.get_conversation_history.return_value = ( + "User: hi\nAssistant: hello" + ) + + rag.query("follow up question", session_id="session_1") + + call_kwargs = rag.ai_generator.generate_response.call_args + assert ( + call_kwargs.kwargs["conversation_history"] == "User: hi\nAssistant: hello" + ) diff --git a/backend/tests/test_search_tools.py b/backend/tests/test_search_tools.py new file mode 100644 index 000000000..ce6da52a6 --- /dev/null +++ b/backend/tests/test_search_tools.py @@ -0,0 +1,114 @@ +"""Tests for CourseSearchTool.execute() and ToolManager dispatch.""" + +import pytest +from unittest.mock import MagicMock, patch +from helpers import make_search_results, make_valid_search_results +from search_tools import CourseSearchTool, ToolManager +from vector_store import SearchResults + + +@pytest.fixture +def mock_store(): + store = MagicMock() + store.get_lesson_link = MagicMock(return_value="https://example.com/lesson") + return store + + +@pytest.fixture +def search_tool(mock_store): + return CourseSearchTool(mock_store) + + +@pytest.fixture +def tool_manager(search_tool): + tm = ToolManager() + tm.register_tool(search_tool) + return tm + + +class TestCourseSearchToolExecute: + def test_execute_returns_formatted_results(self, search_tool, mock_store): + """Valid search results are formatted as [Course - Lesson N]\\ncontent.""" + results = make_valid_search_results(2) + mock_store.search.return_value = results + + output = search_tool.execute(query="test query") + + assert 
"[Course 0 - Lesson 1]" in output + assert "Content about topic 0" in output + assert "[Course 1 - Lesson 2]" in output + assert "Content about topic 1" in output + + def test_execute_populates_sources(self, search_tool, mock_store): + """last_sources and last_source_links are populated after execution.""" + results = make_valid_search_results(2) + mock_store.search.return_value = results + + search_tool.execute(query="test query") + + assert len(search_tool.last_sources) == 2 + assert "Course 0 - Lesson 1" in search_tool.last_sources + assert len(search_tool.last_source_links) == 2 + + def test_execute_error_from_search(self, search_tool, mock_store): + """When SearchResults.error is set, execute returns the error string.""" + mock_store.search.return_value = make_search_results( + error="No course found matching 'xyz'" + ) + + output = search_tool.execute(query="test", course_name="xyz") + + assert "No course found matching 'xyz'" in output + + def test_execute_empty_results(self, search_tool, mock_store): + """When no documents found, returns 'No relevant content found'.""" + mock_store.search.return_value = make_search_results() + + output = search_tool.execute(query="nonexistent topic") + + assert "No relevant content found" in output + + def test_execute_empty_with_filters(self, search_tool, mock_store): + """Empty results with course_name/lesson filters include filter info.""" + mock_store.search.return_value = make_search_results() + + output = search_tool.execute(query="topic", course_name="MCP", lesson_number=3) + + assert "in course 'MCP'" in output + assert "in lesson 3" in output + + def test_execute_exception_propagates(self, search_tool, mock_store): + """When store.search() raises, exception propagates (not caught).""" + mock_store.search.side_effect = RuntimeError("DB connection failed") + + with pytest.raises(RuntimeError, match="DB connection failed"): + search_tool.execute(query="test") + + def test_tool_definition_schema(self, search_tool): + 
"""Tool definition has correct name, required params, schema.""" + defn = search_tool.get_tool_definition() + + assert defn["name"] == "search_course_content" + assert defn["input_schema"]["required"] == ["query"] + assert "query" in defn["input_schema"]["properties"] + assert "course_name" in defn["input_schema"]["properties"] + assert "lesson_number" in defn["input_schema"]["properties"] + + +class TestToolManager: + def test_dispatches_correctly(self, tool_manager, mock_store): + """ToolManager.execute_tool dispatches to the right tool.""" + mock_store.search.return_value = make_valid_search_results(1) + + result = tool_manager.execute_tool("search_course_content", query="test query") + + mock_store.search.assert_called_once_with( + query="test query", course_name=None, lesson_number=None + ) + assert "[Course 0 - Lesson 1]" in result + + def test_unknown_tool_returns_error(self, tool_manager): + """Unknown tool name returns error string, not exception.""" + result = tool_manager.execute_tool("nonexistent_tool", query="test") + + assert "not found" in result.lower() diff --git a/backend/vector_store.py b/backend/vector_store.py index 390abe71c..c14e2f03c 100644 --- a/backend/vector_store.py +++ b/backend/vector_store.py @@ -5,73 +5,88 @@ from models import Course, CourseChunk from sentence_transformers import SentenceTransformer + @dataclass class SearchResults: """Container for search results with metadata""" + documents: List[str] metadata: List[Dict[str, Any]] distances: List[float] error: Optional[str] = None - + @classmethod - def from_chroma(cls, chroma_results: Dict) -> 'SearchResults': + def from_chroma(cls, chroma_results: Dict) -> "SearchResults": """Create SearchResults from ChromaDB query results""" return cls( - documents=chroma_results['documents'][0] if chroma_results['documents'] else [], - metadata=chroma_results['metadatas'][0] if chroma_results['metadatas'] else [], - distances=chroma_results['distances'][0] if chroma_results['distances'] else 
[] + documents=( + chroma_results["documents"][0] if chroma_results["documents"] else [] + ), + metadata=( + chroma_results["metadatas"][0] if chroma_results["metadatas"] else [] + ), + distances=( + chroma_results["distances"][0] if chroma_results["distances"] else [] + ), ) - + @classmethod - def empty(cls, error_msg: str) -> 'SearchResults': + def empty(cls, error_msg: str) -> "SearchResults": """Create empty results with error message""" return cls(documents=[], metadata=[], distances=[], error=error_msg) - + def is_empty(self) -> bool: """Check if results are empty""" return len(self.documents) == 0 + class VectorStore: """Vector storage using ChromaDB for course content and metadata""" - + def __init__(self, chroma_path: str, embedding_model: str, max_results: int = 5): self.max_results = max_results # Initialize ChromaDB client self.client = chromadb.PersistentClient( - path=chroma_path, - settings=Settings(anonymized_telemetry=False) + path=chroma_path, settings=Settings(anonymized_telemetry=False) ) - + # Set up sentence transformer embedding function - self.embedding_function = chromadb.utils.embedding_functions.SentenceTransformerEmbeddingFunction( - model_name=embedding_model + self.embedding_function = ( + chromadb.utils.embedding_functions.SentenceTransformerEmbeddingFunction( + model_name=embedding_model + ) ) - + # Create collections for different types of data - self.course_catalog = self._create_collection("course_catalog") # Course titles/instructors - self.course_content = self._create_collection("course_content") # Actual course material - + self.course_catalog = self._create_collection( + "course_catalog" + ) # Course titles/instructors + self.course_content = self._create_collection( + "course_content" + ) # Actual course material + def _create_collection(self, name: str): """Create or get a ChromaDB collection""" return self.client.get_or_create_collection( - name=name, - embedding_function=self.embedding_function + name=name, 
embedding_function=self.embedding_function ) - - def search(self, - query: str, - course_name: Optional[str] = None, - lesson_number: Optional[int] = None, - limit: Optional[int] = None) -> SearchResults: + + def search( + self, + query: str, + course_name: Optional[str] = None, + lesson_number: Optional[int] = None, + limit: Optional[int] = None, + ) -> SearchResults: """ Main search interface that handles course resolution and content search. - + Args: query: What to search for in course content course_name: Optional course name/title to filter by lesson_number: Optional lesson number to filter by limit: Maximum results to return - + Returns: SearchResults object with documents and metadata """ @@ -81,104 +96,111 @@ def search(self, course_title = self._resolve_course_name(course_name) if not course_title: return SearchResults.empty(f"No course found matching '{course_name}'") - + # Step 2: Build filter for content search filter_dict = self._build_filter(course_title, lesson_number) - + # Step 3: Search course content # Use provided limit or fall back to configured max_results search_limit = limit if limit is not None else self.max_results - + try: results = self.course_content.query( - query_texts=[query], - n_results=search_limit, - where=filter_dict + query_texts=[query], n_results=search_limit, where=filter_dict ) return SearchResults.from_chroma(results) except Exception as e: return SearchResults.empty(f"Search error: {str(e)}") - + def _resolve_course_name(self, course_name: str) -> Optional[str]: """Use vector search to find best matching course by name""" try: - results = self.course_catalog.query( - query_texts=[course_name], - n_results=1 - ) - - if results['documents'][0] and results['metadatas'][0]: + results = self.course_catalog.query(query_texts=[course_name], n_results=1) + + if results["documents"][0] and results["metadatas"][0]: # Return the title (which is now the ID) - return results['metadatas'][0][0]['title'] + return 
results["metadatas"][0][0]["title"] except Exception as e: print(f"Error resolving course name: {e}") - + return None - - def _build_filter(self, course_title: Optional[str], lesson_number: Optional[int]) -> Optional[Dict]: + + def _build_filter( + self, course_title: Optional[str], lesson_number: Optional[int] + ) -> Optional[Dict]: """Build ChromaDB filter from search parameters""" if not course_title and lesson_number is None: return None - + # Handle different filter combinations if course_title and lesson_number is not None: - return {"$and": [ - {"course_title": course_title}, - {"lesson_number": lesson_number} - ]} - + return { + "$and": [ + {"course_title": course_title}, + {"lesson_number": lesson_number}, + ] + } + if course_title: return {"course_title": course_title} - + return {"lesson_number": lesson_number} - + def add_course_metadata(self, course: Course): """Add course information to the catalog for semantic search""" import json course_text = course.title - + # Build lessons metadata and serialize as JSON string lessons_metadata = [] for lesson in course.lessons: - lessons_metadata.append({ - "lesson_number": lesson.lesson_number, - "lesson_title": lesson.title, - "lesson_link": lesson.lesson_link - }) - + lessons_metadata.append( + { + "lesson_number": lesson.lesson_number, + "lesson_title": lesson.title, + "lesson_link": lesson.lesson_link, + } + ) + self.course_catalog.add( documents=[course_text], - metadatas=[{ - "title": course.title, - "instructor": course.instructor, - "course_link": course.course_link, - "lessons_json": json.dumps(lessons_metadata), # Serialize as JSON string - "lesson_count": len(course.lessons) - }], - ids=[course.title] + metadatas=[ + { + "title": course.title, + "instructor": course.instructor, + "course_link": course.course_link, + "lessons_json": json.dumps( + lessons_metadata + ), # Serialize as JSON string + "lesson_count": len(course.lessons), + } + ], + ids=[course.title], ) - + def add_course_content(self, 
chunks: List[CourseChunk]): """Add course content chunks to the vector store""" if not chunks: return - + documents = [chunk.content for chunk in chunks] - metadatas = [{ - "course_title": chunk.course_title, - "lesson_number": chunk.lesson_number, - "chunk_index": chunk.chunk_index - } for chunk in chunks] + metadatas = [ + { + "course_title": chunk.course_title, + "lesson_number": chunk.lesson_number, + "chunk_index": chunk.chunk_index, + } + for chunk in chunks + ] # Use title with chunk index for unique IDs - ids = [f"{chunk.course_title.replace(' ', '_')}_{chunk.chunk_index}" for chunk in chunks] - - self.course_content.add( - documents=documents, - metadatas=metadatas, - ids=ids - ) - + ids = [ + f"{chunk.course_title.replace(' ', '_')}_{chunk.chunk_index}" + for chunk in chunks + ] + + self.course_content.add(documents=documents, metadatas=metadatas, ids=ids) + def clear_all_data(self): """Clear all data from both collections""" try: @@ -189,43 +211,46 @@ def clear_all_data(self): self.course_content = self._create_collection("course_content") except Exception as e: print(f"Error clearing data: {e}") - + def get_existing_course_titles(self) -> List[str]: """Get all existing course titles from the vector store""" try: # Get all documents from the catalog results = self.course_catalog.get() - if results and 'ids' in results: - return results['ids'] + if results and "ids" in results: + return results["ids"] return [] except Exception as e: print(f"Error getting existing course titles: {e}") return [] - + def get_course_count(self) -> int: """Get the total number of courses in the vector store""" try: results = self.course_catalog.get() - if results and 'ids' in results: - return len(results['ids']) + if results and "ids" in results: + return len(results["ids"]) return 0 except Exception as e: print(f"Error getting course count: {e}") return 0 - + def get_all_courses_metadata(self) -> List[Dict[str, Any]]: """Get metadata for all courses in the vector store""" 
import json + try: results = self.course_catalog.get() - if results and 'metadatas' in results: + if results and "metadatas" in results: # Parse lessons JSON for each course parsed_metadata = [] - for metadata in results['metadatas']: + for metadata in results["metadatas"]: course_meta = metadata.copy() - if 'lessons_json' in course_meta: - course_meta['lessons'] = json.loads(course_meta['lessons_json']) - del course_meta['lessons_json'] # Remove the JSON string version + if "lessons_json" in course_meta: + course_meta["lessons"] = json.loads(course_meta["lessons_json"]) + del course_meta[ + "lessons_json" + ] # Remove the JSON string version parsed_metadata.append(course_meta) return parsed_metadata return [] @@ -238,30 +263,67 @@ def get_course_link(self, course_title: str) -> Optional[str]: try: # Get course by ID (title is the ID) results = self.course_catalog.get(ids=[course_title]) - if results and 'metadatas' in results and results['metadatas']: - metadata = results['metadatas'][0] - return metadata.get('course_link') + if results and "metadatas" in results and results["metadatas"]: + metadata = results["metadatas"][0] + return metadata.get("course_link") return None except Exception as e: print(f"Error getting course link: {e}") return None - + def get_lesson_link(self, course_title: str, lesson_number: int) -> Optional[str]: """Get lesson link for a given course title and lesson number""" import json + try: # Get course by ID (title is the ID) results = self.course_catalog.get(ids=[course_title]) - if results and 'metadatas' in results and results['metadatas']: - metadata = results['metadatas'][0] - lessons_json = metadata.get('lessons_json') + if results and "metadatas" in results and results["metadatas"]: + metadata = results["metadatas"][0] + lessons_json = metadata.get("lessons_json") if lessons_json: lessons = json.loads(lessons_json) # Find the lesson with matching number for lesson in lessons: - if lesson.get('lesson_number') == lesson_number: - 
return lesson.get('lesson_link') + if lesson.get("lesson_number") == lesson_number: + return lesson.get("lesson_link") return None except Exception as e: print(f"Error getting lesson link: {e}") - \ No newline at end of file + return None + + def get_course_outline(self, course_name: str) -> Optional[Dict[str, Any]]: + """ + Get the complete outline of a course including all lessons. + + Args: + course_name: Course name or partial name (fuzzy matching supported) + + Returns: + Dictionary with course_title, course_link, and lessons list, or None if not found + """ + import json + + # Resolve course name using fuzzy matching + course_title = self._resolve_course_name(course_name) + if not course_title: + return None + + try: + # Get course metadata by ID (title is the ID) + results = self.course_catalog.get(ids=[course_title]) + if results and "metadatas" in results and results["metadatas"]: + metadata = results["metadatas"][0] + lessons_json = metadata.get("lessons_json") + + if lessons_json: + lessons = json.loads(lessons_json) + return { + "course_title": metadata.get("title"), + "course_link": metadata.get("course_link"), + "lessons": lessons, + } + return None + except Exception as e: + print(f"Error getting course outline: {e}") + return None diff --git a/frontend-changes.md b/frontend-changes.md new file mode 100644 index 000000000..3371786fd --- /dev/null +++ b/frontend-changes.md @@ -0,0 +1,165 @@ +# Frontend Changes + +## Change 1: Dark/Light Theme Toggle Button + +### Summary +Added a theme toggle button (sun/moon icons) in the top-right corner that switches between dark and light modes with smooth CSS transitions. User preference is persisted via `localStorage`. + +### Files Changed + +#### `frontend/index.html` +- Added a `

Course Materials Assistant

Ask questions about courses, instructors, and content

@@ -19,6 +51,14 @@

Course Materials Assistant

+ +
+ +
+
@@ -76,6 +116,6 @@

Course Materials Assistant

- + \ No newline at end of file diff --git a/frontend/script.js b/frontend/script.js index 562a8a363..ce410f51c 100644 --- a/frontend/script.js +++ b/frontend/script.js @@ -15,7 +15,8 @@ document.addEventListener('DOMContentLoaded', () => { sendButton = document.getElementById('sendButton'); totalCourses = document.getElementById('totalCourses'); courseTitles = document.getElementById('courseTitles'); - + + initTheme(); setupEventListeners(); createNewSession(); loadCourseStats(); @@ -38,6 +39,20 @@ function setupEventListeners() { sendMessage(); }); }); + + // Theme toggle + const themeToggle = document.getElementById('themeToggle'); + if (themeToggle) { + themeToggle.addEventListener('click', toggleTheme); + } + + // New chat button + const newChatButton = document.getElementById('newChatButton'); + if (newChatButton) { + newChatButton.addEventListener('click', () => { + createNewSession(); + }); + } } @@ -122,10 +137,28 @@ function addMessage(content, type, sources = null, isWelcome = false) { let html = `
${displayContent}
`; if (sources && sources.length > 0) { + // Format sources - handle both string and object formats + const formattedSources = sources.map(source => { + // Handle legacy string format (backward compatibility) + if (typeof source === 'string') { + return escapeHtml(source); + } + // Handle new object format with optional link + if (typeof source === 'object' && source.text) { + const text = escapeHtml(source.text); + if (source.link) { + // Create clickable link that opens in new tab + return `${text}`; + } + return text; + } + return ''; + }).filter(s => s); // Remove empty strings + html += `
Sources -
${sources.join(', ')}
+
${formattedSources.join(', ')}
`; } @@ -147,9 +180,111 @@ function escapeHtml(text) { // Removed removeMessage function - no longer needed since we handle loading differently async function createNewSession() { + // Store old session ID for potential backend cleanup + const oldSessionId = currentSessionId; + + // Reset session state currentSessionId = null; + + // Clear UI chatMessages.innerHTML = ''; + + // Add welcome message addMessage('Welcome to the Course Materials Assistant! I can help you with questions about courses, lessons and specific content. What would you like to know?', 'assistant', null, true); + + // Focus input for immediate use + if (chatInput) { + chatInput.focus(); + } + + // Clear backend session if one existed + if (oldSessionId) { + clearBackendSession(oldSessionId); + } +} + +async function clearBackendSession(sessionId) { + if (!sessionId) return; + + try { + await fetch(`${API_URL}/session/clear`, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + }, + body: JSON.stringify({ session_id: sessionId }) + }); + } catch (error) { + // Silent fail - session will be garbage collected eventually + console.warn('Failed to clear backend session:', error); + } +} + +// Theme Toggle +function initTheme() { + // Theme was already applied by the inline