From 571b1faec5fd3022b65bf86cb23cdd8847fc548b Mon Sep 17 00:00:00 2001 From: Alexander Zuev Date: Sat, 19 Oct 2024 16:09:08 +0400 Subject: [PATCH] docs: updated docstrings for all modules, functions in the application --- README.md | 2 +- app.py | 103 +++++++-- chainlit_app.py | 97 -------- chainlit_poc.py | 147 ------------ src/generation/claude_assistant.py | 329 ++++++++++++++++++++++----- src/generation/summary_manager.py | 30 ++- src/generation/tool_definitions.py | 44 +++- src/processing/chunking.py | 349 +++++++++++++++++++++++++++-- src/utils/decorators.py | 47 ++++ src/utils/logger.py | 51 ++++- src/utils/output_formatter.py | 34 +++ src/vector_storage/vector_db.py | 297 ++++++++++++++++++++++-- tests/test_app.py | 6 + tests/test_chunking.py | 11 + tests/test_claude_assistant.py | 49 +++- tests/test_crawler.py | 12 + tests/test_vector_db.py | 25 +++ 17 files changed, 1265 insertions(+), 368 deletions(-) delete mode 100644 chainlit_app.py delete mode 100644 chainlit_poc.py diff --git a/README.md b/README.md index f26addc..5369426 100644 --- a/README.md +++ b/README.md @@ -170,7 +170,7 @@ See the [LICENSE](LICENSE.md) file for the full license text and additional cond The project has been renamed from **OmniClaude** to **Kollektiv** to: - avoid confusion / unintended copyright infringement of Anthropic -- emphasize the goal to become a tool to enhance collaboration through simplifying access to knowledge +- emphasize the goal to become a tool to enhance collaboration through simplifying access to knowledge - overall cool name (isn't it?) If you have any questions regarding the renaming, feel free to reach out. diff --git a/app.py b/app.py index 2483d99..edcc186 100644 --- a/app.py +++ b/app.py @@ -1,7 +1,34 @@ +import chainlit as cl + from src.core.component_initializer import ComponentInitializer from src.ui.terminal_ui import run_terminal_ui from src.utils.decorators import base_error_handler -from src.utils.logger import configure_logging +from src.utils.logger import configure_logging, get_logger + +logger = get_logger() + + +@base_error_handler +def setup_chainlit(): + """ + Set up and initialize the Chainlit environment. + + Args: + None + + Returns: + ClaudeAssistant: An instance of the ClaudeAssistant initialized with the specified documents. + + Raises: + BaseError: If there is an error during the initialization process. + """ + docs = [ + "docs_anthropic_com_en_20240928_135426-chunked.json", + "langchain-ai_github_io_langgraph_20240928_210913-chunked.json", + ] + initializer = ComponentInitializer(reset_db=False, load_all_docs=False, files=docs) + claude_assistant = initializer.init() + return claude_assistant @base_error_handler @@ -13,29 +40,73 @@ def main(debug: bool = False, reset_db: bool = False): debug (bool): Defines if debug mode should be enabled. reset_db (bool): Indicates whether the database should be reset. - Raises: - SomeCustomException: Raises a custom exception if an error occurs during initialization. 
- Returns: None """ # Configure logging before importing other modules configure_logging(debug=debug) - # TODO: allow users to select documents locally or pull from a db - docs = [ - "docs_anthropic_com_en_20240928_135426-chunked.json", - "langchain-ai_github_io_langgraph_20240928_210913-chunked.json", - # "docs_ragas_io_en_stable_20241015_112520-chunked.json", - ] - # Initialize components - initializer = ComponentInitializer(reset_db=reset_db, load_all_docs=False, files=docs) - claude_assistant = initializer.init() - + claude_assistant = setup_chainlit() run_terminal_ui(claude_assistant) -if __name__ == "__main__": # D100 - # TODO: reset DB should be a command a user can use +# Initialize the assistant when the Chainlit app starts +assistant = setup_chainlit() + + +@cl.on_chat_start +async def on_chat_start(): + """ + Handle the chat start event and send initial welcome messages. + + Args: + None + + Returns: + None + """ + await cl.Message(content="Hello! I'm Kollektiv, sync any web content and let's chat!").send() + + +@cl.on_message +async def handle_message(message: cl.Message): + """ + Handle an incoming message from the CL framework. + + Args: + message (cl.Message): The message object containing the user's input. + + Returns: + None + """ + if assistant is None: + logger.error("Assistant instance is not initialized.") + await cl.Message(content="Error: Assistant is not initialized.").send() + return + + response = assistant.get_response(user_input=message.content, stream=True) + + current_message = cl.Message(content="") + await current_message.send() + + tool_used = False + + for event in response: + if event["type"] == "text": + if tool_used: + # If a tool was used, start a new message for the assistant's response + current_message = cl.Message(content="") + await current_message.send() + tool_used = False + await current_message.stream_token(event["content"]) + elif event["type"] == "tool_use": + tool_name = event.get("tool", "Unknown tool") + await cl.Message(content=f"🛠️ Using {tool_name} tool.").send() + tool_used = True + + await current_message.update() + + +if __name__ == "__main__": main(debug=False, reset_db=False) diff --git a/chainlit_app.py b/chainlit_app.py deleted file mode 100644 index 1f6e677..0000000 --- a/chainlit_app.py +++ /dev/null @@ -1,97 +0,0 @@ -import chainlit as cl - -from src.core.component_initializer import ComponentInitializer -from src.utils.decorators import base_error_handler -from src.utils.logger import get_logger - -logger = get_logger() - - -# Initialize the assistant -@base_error_handler -def setup_chainlit(): - """ - Set up and initialize the Chainlit environment. - - Args: - None - - Returns: - ClaudeAssistant: An instance of the ClaudeAssistant initialized with the specified documents. - - Raises: - BaseError: If there is an error during the initialization process. - """ - docs = [ - "docs_anthropic_com_en_20240928_135426-chunked.json", - "langchain-ai_github_io_langgraph_20240928_210913-chunked.json", - # "docs_ragas_io_en_stable_20241015_112520-chunked.json", - ] - # Initialize components - initializer = ComponentInitializer(reset_db=False, load_all_docs=False, files=docs) - claude_assistant = initializer.init() - return claude_assistant - - -# Initialize the assistant when the Chainlit app starts -assistant = setup_chainlit() - - -@cl.on_chat_start -async def on_chat_start(): - """ - Handle the chat start event and send initial welcome messages. 
- - Args: - None - - Returns: - None - - Raises: - Exception: If there is an issue with sending messages. - """ - # TODO: send a first message to let user know which documents it has access to. - await cl.Message(content="Hello! I'm Kollektiv, sync any web content and and let's chat!").send() - - -@cl.on_message -async def handle_message(message: cl.Message): - """ - Handle an incoming message from the CL framework. - - Args: - message (cl.Message): The message object containing the user's input. - - Returns: - None - - Raises: - RuntimeError: If the assistant instance is not initialized. - """ - if assistant is None: - logger.error("Assistant instance is not initialized.") - await cl.Message(content="Error: Assistant is not initialized.").send() - return - - response = assistant.get_response(user_input=message.content, stream=True) - - current_message = cl.Message(content="") - await current_message.send() - - tool_used = False - - for event in response: - if event["type"] == "text": - if tool_used: - # If a tool was used, start a new message for the assistant's response - current_message = cl.Message(content="") - await current_message.send() - tool_used = False - await current_message.stream_token(event["content"]) - elif event["type"] == "tool_use": - tool_name = event.get("tool", "Unknown tool") - await cl.Message(content=f"🛠️ Using {tool_name} tool.").send() - tool_used = True - - await current_message.update() diff --git a/chainlit_poc.py b/chainlit_poc.py deleted file mode 100644 index 370189b..0000000 --- a/chainlit_poc.py +++ /dev/null @@ -1,147 +0,0 @@ -from typing import Any, Dict - -import chainlit as cl - - -class MockFirecrawlTool: - def index(self, indexing_info: dict[str, Any]): - print(f"Indexing with parameters: {indexing_info}") - - -firecrawl_tool = MockFirecrawlTool() - - -@cl.on_chat_start -async def on_chat_start(): - welcome_text = """ - Welcome to the Kollektiv docs bot! - - Available commands: - - @docs add "url" : Add a new URL to index - - @docs list : List all indexed URLs - - @docs remove "url" : Remove a URL from the index - - Type a command to get started. - """ - await cl.Message(content=welcome_text).send() - - -@cl.on_message -async def handle_message(message: cl.Message): - if message.content.startswith("@docs add"): - await handle_docs_add_command(message.content) - elif message.content.startswith("@docs list"): - await handle_docs_list_command() - elif message.content.startswith("@docs remove"): - await handle_docs_remove_command(message.content) - else: - await cl.Message(content="Invalid command. Use @docs add, list, or remove.").send() - - -async def handle_docs_add_command(command: str): - parts = command.split('"') - if len(parts) != 3: - await cl.Message(content='Invalid command. 
Use @docs add "url"').send() - return - - url = parts[1] - await start_indexing_flow(url) - - -async def start_indexing_flow(url: str): - await cl.Message(content=f"Starting indexing process for {url}").send() - - # Exclude patterns - exclude_guidance = cl.Text( - content="""Exclude Patterns Guide: - - Use patterns to exclude specific URLs or directories - - Separate multiple patterns with commas - - Examples: - - /blog/* (excludes all blog posts) - - /author/*, /tag/* (excludes author and tag pages)""", - name="Exclude Patterns Guide", - display="inline", - ) - - await cl.Message(content="Specify patterns to exclude:", elements=[exclude_guidance]).send() - exclude_patterns_msg = await cl.AskUserMessage(content="Enter exclude patterns (comma-separated):").send() - exclude_patterns = exclude_patterns_msg["output"] if exclude_patterns_msg else "" - - # Max pages - max_pages_guidance = cl.Text( - content="""Max Pages Guide: - - Set the maximum number of pages to crawl - - Use a reasonable number to avoid overloading the server - - Example: 100 (will crawl up to 100 pages)""", - name="Max Pages Guide", - display="inline", - ) - - await cl.Message(content="Specify maximum pages to crawl:", elements=[max_pages_guidance]).send() - max_pages_msg = await cl.AskUserMessage(content="Enter maximum pages number:").send() - max_pages = max_pages_msg["output"] if max_pages_msg else "" - - # Crawl depth - crawl_depth_guidance = cl.Text( - content="""Crawl Depth Guide: - - Determines how deep the crawler will go into the site structure - - Higher numbers may result in longer crawl times - - Examples: - - 1 (only the homepage) - - 3 (homepage, linked pages, and pages linked from those)""", - name="Crawl Depth Guide", - display="inline", - ) - - await cl.Message(content="Specify crawl depth:", elements=[crawl_depth_guidance]).send() - crawl_depth_msg = await cl.AskUserMessage(content="Enter crawl depth number:").send() - crawl_depth = crawl_depth_msg["output"] if crawl_depth_msg else "" - - indexing_info = { - "url": url, - "exclude_patterns": exclude_patterns, - "max_pages": max_pages, - "crawl_depth": crawl_depth, - } - - # Display collected information - info_message = f""" - Indexing content with the following parameters: - - URL: {url} - - Exclude patterns: {exclude_patterns} - - Max pages: {max_pages} - - Crawl depth: {crawl_depth} - """ - await cl.Message(content=info_message).send() - - # Mock indexing process with progress updates - await cl.Message(content="Starting indexing process...").send() - await cl.Message(content="Crawling pages...").send() - await cl.Message(content="Processing content...").send() - firecrawl_tool.index(indexing_info) - await cl.Message(content=f"Content from {url} has been indexed successfully.").send() - - await cl.Message(content="Indexing process completed. You can now ask questions about the indexed content.").send() - - -async def handle_docs_list_command(): - # This would typically fetch from a database. Using a mock list for demonstration. - indexed_urls = ["https://example.com", "https://another-example.com"] - list_message = "Currently indexed URLs:\n" + "\n".join(f"- {url}" for url in indexed_urls) - await cl.Message(content=list_message).send() - - -async def handle_docs_remove_command(command: str): - parts = command.split('"') - if len(parts) != 3: - await cl.Message(content='Invalid command. Use @docs remove "url"').send() - return - - url = parts[1] - # This would typically remove from a database. Using a mock removal for demonstration. 
- await cl.Message(content=f"Removing {url} from the index...").send() - await cl.Message(content=f"{url} has been removed from the index.").send() - - -if __name__ == "__main__": - cl.run() diff --git a/src/generation/claude_assistant.py b/src/generation/claude_assistant.py index b718c16..7eb93f4 100644 --- a/src/generation/claude_assistant.py +++ b/src/generation/claude_assistant.py @@ -30,7 +30,8 @@ class ConversationMessage: Args: role (str): The role of the message sender (e.g., 'user', 'system'). - content (str | list[dict[str, Any]]): The content of the message, which can be a string or a list of dictionaries. + content (str | list[dict[str, Any]]): The content of the message, which can be a string or a list of + dictionaries. Returns: None @@ -45,6 +46,16 @@ def __init__(self, role: str, content: str | list[dict[str, Any]]): self.content = content def to_dict(self, include_id: bool = False) -> dict[str, Any]: + """ + Convert the message object to a dictionary. + + Args: + include_id (bool): Whether to include the ID of the message in the dictionary. + + Returns: + dict[str, Any]: A dictionary representation of the message with keys "role" and "content". + Optionally includes "id" if `include_id` is True. + """ message_dict = {"role": self.role, "content": self.content} if include_id: message_dict["id"] = self.id @@ -70,6 +81,21 @@ def __init__(self, max_tokens: int = 200000, tokenizer: str = "cl100k_base"): self.tokenizer = tiktoken.get_encoding(tokenizer) def add_message(self, role: str, content: str | list[dict[str, Any]]) -> None: + """ + Add a message to the conversation and adjust for token limits. + + Args: + role (str): The role of the sender, e.g., 'user', 'system', or 'assistant'. + content (str or list[dict[str, Any]]): The content of the message, which can be a string or a list of + dictionaries. + + Returns: + None + + Raises: + ValueError: If the role is not recognized. + MemoryError: If the total tokens exceed the maximum allowed even after pruning. + """ message = ConversationMessage(role, content) self.messages.append(message) logger.debug(f"Added message for role={message.role}") @@ -82,6 +108,19 @@ def add_message(self, role: str, content: str | list[dict[str, Any]]) -> None: self.total_tokens += estimated_tokens def update_token_count(self, input_tokens: int, output_tokens: int) -> None: + """ + Update the total token count and prune history if maximum is exceeded. + + Args: + input_tokens (int): The number of input tokens. + output_tokens (int): The number of output tokens. + + Returns: + None + + Raises: + None + """ self.total_tokens = input_tokens + output_tokens if self.total_tokens > self.max_tokens: self._prune_history(0) @@ -104,50 +143,75 @@ def _prune_history(self, new_tokens: int) -> None: self.total_tokens -= self._estimate_tokens(removed_message.content) def remove_last_message(self) -> None: + """ + Remove the last message from the messages list and log its role. + + Args: + self: The instance of the class containing the messages list. + + Returns: + None + + Raises: + None + """ if self.messages: removed_message = self.messages.pop() logger.info(f"Removed message: {removed_message.role}") def get_conversation_history(self, debug: bool = False) -> list[dict[str, Any]]: + """ + Retrieve the conversation history. + + Args: + debug (bool): If True, include message IDs in the output. + + Returns: + list[dict[str, Any]]: A list of dictionaries representing the messages. 
+
+        Raises:
+            None
+        """
         return [msg.to_dict(include_id=debug) for msg in self.messages]
 
     def log_conversation_state(self) -> None:
+        """
+        Log the current state of the conversation.
+
+        Args:
+            self: Instance of the class containing the current conversation state.
+
+        Returns:
+            None
+
+        Raises:
+            None
+        """
         logger.debug(f"Conversation state: messages={len(self.messages)}, " f"Total tokens={self.total_tokens}, ")
 
 
 # Client Class
 class ClaudeAssistant(Model):
     """
-    Define the ClaudeAssistant class for interacting with the Anthropoc API and managing conversational history.
+    Define the ClaudeAssistant class for managing AI assistant functionalities with various tools and configurations.
 
     Args:
-        vector_db (VectorDB): An instance of VectorDB for handling vector database operations.
-        api_key (str, optional): API key for authenticating with the Anthropoc service. Defaults to None.
-        model_name (str, optional): Name of the model to use for generating responses. Defaults to None.
-
-    Attributes:
-        client (anthropic.Anthropic | None): Instance for handling API interactions.
-        vector_db (VectorDB): Instance for handling vector database operations.
-        api_key (str): API key for the Anthropoc service.
-        model_name (str): Name of the model to use for generating responses.
-        base_system_prompt (str): Base prompt template for the system.
-        system_prompt (str): Current system prompt.
-        conversation_history (ConversationHistory | None): History of the conversation.
-        retrieved_contexts (list[str]): List of retrieved contexts.
-        tool_manager (Any): Manager for handling tools.
-        tools (list[dict[str, Any]]): List of available tools.
-        extra_headers (dict[str, str]): Extra headers for API requests.
-        retriever (Any | None): Instance for handling data retrieval.
+        vector_db (VectorDB): The vector database instance for retrieving contextual information.
+        api_key (str, optional): The API key for the Anthropic client. Defaults to ANTHROPIC_API_KEY.
+        model_name (str, optional): The name of the model to use. Defaults to MAIN_MODEL.
+
+    Raises:
+        anthropic.AnthropicError: If there's an error initializing the Anthropic client.
 
     Methods:
-        __init__(self, vector_db, api_key=None, model_name=None): Initialize the assistant with the given parameters.
-        _init(self): Initialize the Anthropoc client and conversation history.
-        update_system_prompt(self, document_summaries): Update the system prompt with document summaries.
-        cached_system_prompt(self): Get the cached system prompt.
-        cached_tools(self): Get the cached tools.
-        preprocess_user_input(self, input_text): Preprocess user input text.
-        get_response(self, user_input, stream=True): Generate a response based on user input.
-        stream_response(self, user_input): Stream responses as they are generated.
+        _init: Initialize the assistant's client and tools.
+        update_system_prompt: Update the system prompt with document summaries.
+        cached_system_prompt: Get the cached system prompt as a list.
+        cached_tools: Get the cached tools as a list.
+        preprocess_user_input: Preprocess the user input to remove whitespace and newlines.
+        get_response: Generate a response based on user input, either as a stream or a single string.
+        stream_response: Handle the streaming response from the assistant and manage conversation flow.
+
     """
 
     client: anthropic.Anthropic | None = None
@@ -261,9 +325,15 @@ def _init(self):
     @base_error_handler
     def update_system_prompt(self, document_summaries: list[dict[str, Any]]):
         """
-        :param document_summaries: List of dictionaries containing summary information to be incorporated into
-        the system prompt.
-        :return: None
+        Update the system prompt with document summaries.
+
+        Args:
+            document_summaries (list[dict[str, Any]]): A list of dictionaries where each dictionary
+                contains 'filename', 'summary', and 'keywords' keys.
+
+        Raises:
+            KeyError: If any of the dictionaries in document_summaries does not contain the required keys.
+            Exception: For other unexpected errors.
         """
         logger.info(f"Loading {len(document_summaries)} summaries")
         summaries_text = "\n\n".join(
@@ -277,10 +347,33 @@
 
     @base_error_handler
     def cached_system_prompt(self) -> list[dict[str, Any]]:
+        """
+        Retrieve the cached system prompt.
+
+        Args:
+            self: Instance of the class containing the system prompt.
+
+        Returns:
+            list[dict[str, Any]]: A list with a single dictionary containing the system prompt and its cache control
+                type.
+        """
         return [{"type": "text", "text": self.system_prompt, "cache_control": {"type": "ephemeral"}}]
 
     @base_error_handler
     def cached_tools(self) -> list[dict[str, Any]]:
+        """
+        Return a list of cached tools with specific attributes.
+
+        Args:
+            None
+
+        Returns:
+            list[dict[str, Any]]: A list of dictionaries where each dictionary represents a cached tool with `name`,
+                `description`, `input_schema`, and `cache_control` fields.
+
+        Raises:
+            None
+        """
         return [
             {
                 "name": tool["name"],
@@ -293,6 +386,15 @@ def cached_tools(self) -> list[dict[str, Any]]:
 
     @base_error_handler
     def preprocess_user_input(self, input_text: str) -> str:
+        """
+        Preprocess user input by removing whitespace and replacing newlines.
+
+        Args:
+            input_text (str): The user input text to preprocess.
+
+        Returns:
+            str: The preprocessed user input with no leading/trailing whitespace and newlines replaced by spaces.
+        """
         # Remove any leading/trailing whitespace
         cleaned_input = input_text.strip()
         # Replace newlines with spaces
@@ -302,17 +404,18 @@ def preprocess_user_input(self, input_text: str) -> str:
     @anthropic_error_handler
     def get_response(self, user_input: str, stream: bool = True) -> Generator[str] | str:
         """
-        Generate a response based on user input, either as a stream or a single string.
+        Get the response from the assistant.
 
         Args:
-            user_input (str): The user's input string.
-            stream (bool): Whether to stream the response or not. Defaults to True.
+            user_input (str): The input provided by the user.
+            stream (bool): Indicator whether to stream the response. Defaults to True.
 
         Returns:
-            Generator[str] | str: A generator yielding the response in parts if streaming,
-            or a single string with the entire response if not streaming.
-        """
+            Generator[str] | str: Stream of responses if streaming is enabled, otherwise a single response.
 
+        Raises:
+            Exception: If an error occurs during response generation.
+        """
         if stream:
             assistant_response_stream = self.stream_response(user_input)
             return assistant_response_stream
@@ -322,19 +425,18 @@ def get_response(self, user_input: str, stream: bool = True) -> Generator[str] |
             return assistant_response
 
     @anthropic_error_handler
-    @weave.op()
     def stream_response(self, user_input: str) -> Generator[str]:
         """
-        Handle the streaming response from the assistant and manage the conversation flow.
+ Stream responses from a conversation, handle tool use, and process assistant replies. Args: - user_input (str): The input string provided by the user. + user_input (str): The input provided by the user. - Yields: - dict: A dictionary containing the type of event and the related content. + Returns: + Generator[dict]: A generator yielding dictionaries with 'type' and content keys. Raises: - Exception: If there is any error while generating the response. + Exception: If an error occurs while generating the response. """ # iteration = 0 user_input = self.preprocess_user_input(user_input) @@ -388,19 +490,18 @@ def stream_response(self, user_input: str) -> Generator[str]: raise Exception(f"An error occurred: {str(e)}") from e @anthropic_error_handler - @weave.op() def not_stream_response(self, user_input: str) -> str: """ - Generate a response based on user input while handling potential tool uses and errors. + Process user input and generate an assistant response without streaming. Args: - user_input (str): The input message from the user. + user_input (str): The input string from the user. Returns: - str: The response generated by the assistant. + str: The assistant's response. Raises: - Exception: If an error occurs during response generation. + Exception: If an error occurs during processing or response generation. """ user_input = self.preprocess_user_input(user_input) self.conversation_history.add_message(role="user", content=user_input) @@ -442,9 +543,17 @@ def not_stream_response(self, user_input: str) -> str: @base_error_handler def _process_assistant_response(self, response: PromptCachingBetaMessage | Message) -> str: """ - Process the assistant's response by saving the message and updating the token count. - """ + Process the assistant's response and update the conversation history. + + Args: + response (PromptCachingBetaMessage | Message): The response object from the assistant. + Returns: + str: The text content of the assistant's response. + + Raises: + Exception: Any exception that can be raised by the base_error_handler. + """ logger.debug( f"Cached {response.usage.cache_creation_input_tokens} input tokens. \n" f"Read {response.usage.cache_read_input_tokens} tokens from cache" @@ -460,7 +569,21 @@ def _process_assistant_response(self, response: PromptCachingBetaMessage | Messa @base_error_handler def handle_tool_use(self, tool_name: str, tool_input: dict[str, Any], tool_use_id: str) -> dict[str, Any] | str: + """ + Handle tool use for specified tools. + Args: + tool_name (str): The name of the tool to be used. + tool_input (dict[str, Any]): The input parameters required by the tool. + tool_use_id (str): The unique identifier for this specific tool use. + + Returns: + dict[str, Any] | str: A dictionary containing the tool result if successful, otherwise a string with an + error message. + + Raises: + Exception: If there is any error while executing the tool. + """ try: if tool_name == "rag_search": search_results = self.use_rag_search(tool_input=tool_input) @@ -491,9 +614,19 @@ def handle_tool_use(self, tool_name: str, tool_input: dict[str, Any], tool_use_i @anthropic_error_handler @weave.op() def formulate_rag_query(self, recent_conversation_history: list[dict[str, Any]], important_context: str) -> str: - """ " - Generates a rag search query based on recent conversation history and important context generated by AI - assistant.""" + """ + Formulate a RAG search query based on recent conversation history and important context. 
+ + Args: + recent_conversation_history (list[dict[str, Any]]): A list of conversation history dictionaries. + important_context (str): A string containing important contextual information. + + Returns: + str: A formulated search query for RAG. + + Raises: + ValueError: If 'recent_conversation_history' is empty. + """ logger.debug(f"Important context: {important_context}") if not recent_conversation_history: @@ -549,12 +682,35 @@ def formulate_rag_query(self, recent_conversation_history: list[dict[str, Any]], @base_error_handler def get_recent_context(self, n_messages: int = 6) -> list[dict[str, str]]: - """Retrieves last 6 messages (3 user messages + 3 assistant messages)""" + """ + Retrieve the most recent messages from the conversation history. + + Args: + n_messages (int): The number of recent messages to retrieve. Defaults to 6. + + Returns: + list[dict[str, str]]: A list of dictionaries representing the recent messages. + + Raises: + None + """ recent_messages = self.conversation_history.messages[-n_messages:] return [msg.to_dict() for msg in recent_messages] @base_error_handler def use_rag_search(self, tool_input: dict[str, Any]) -> list[str]: + """ + Perform a retrieval-augmented generation (RAG) search using the provided tool input. + + Args: + tool_input (dict[str, Any]): A dictionary containing 'important_context' key to formulate the RAG query. + + Returns: + list[str]: A list of preprocessed ranked documents resulting from the RAG search. + + Raises: + KeyError: If 'important_context' is not found in tool_input. + """ # Get recent conversation context (last n messages for each role) recent_conversation_history = self.get_recent_context() @@ -581,7 +737,16 @@ def use_rag_search(self, tool_input: dict[str, Any]) -> list[str]: @base_error_handler def preprocess_ranked_documents(self, ranked_documents: dict[str, Any]) -> list[str]: """ - Converts ranked documents into a structured string for passing to the Claude API. + Preprocess ranked documents to generate a list of formatted document strings. + + Args: + ranked_documents (dict[str, Any]): A dictionary where keys are document identifiers and values are + dictionaries + containing document details such as 'relevance_score' and 'text'. + + Returns: + list[str]: A list of formatted document strings, each containing the document's relevance score and text. + """ preprocessed_context = [] @@ -600,6 +765,20 @@ def preprocess_ranked_documents(self, ranked_documents: dict[str, Any]) -> list[ @anthropic_error_handler @weave.op() def generate_multi_query(self, query: str, model: str = None, n_queries: int = 5) -> list[str]: + """ + Generate multiple related queries based on a user query. + + Args: + query (str): The original user query. + model (str, optional): The model used for generating queries. Defaults to None. + n_queries (int, optional): The number of related queries to generate. Defaults to 5. + + Returns: + list[str]: A list of generated queries related to the original query. + + Raises: + Exception: If there is an error in the message generation process. + """ prompt = f""" You are an AI assistant whose task is to generate multiple queries as part of a RAG system. You are helping users retrieve relevant information from a vector database. 
@@ -630,7 +809,17 @@ def generate_multi_query(self, query: str, model: str = None, n_queries: int = 5 @base_error_handler def combine_queries(self, user_query: str, generated_queries: list[str]) -> list[str]: """ - Combines user query and generated queries into a list, removing any empty queries. + Combine user query with generated queries. + + Args: + user_query (str): The initial user-provided query. + generated_queries (list[str]): A list of queries generated by the system. + + Returns: + list[str]: A list containing the user query and the filtered generated queries. + + Raises: + None """ combined_queries = [query for query in [user_query] + generated_queries if query.strip()] return combined_queries @@ -638,7 +827,19 @@ def combine_queries(self, user_query: str, generated_queries: list[str]) -> list @anthropic_error_handler @weave.op() async def predict(self, question: str) -> dict: - """Should match the keys in the Dataset that is passed for evaluation""" + """ + Predict the answer to the given question. + + Args: + question (str): The question for which an answer is to be predicted. + + Returns: + dict: A dictionary containing the answer and the contexts retrieved. + + Raises: + TypeError: If the `question` is not of type `str`. + Exception: If there is an error in getting the response. + """ logger.debug(f"Predict method called with row: {question}") # user_input = row.get('question', '') @@ -659,5 +860,19 @@ async def predict(self, question: str) -> dict: } def reset_conversation(self): + """ + Reset the conversation state to its initial state. + + Resets the conversation history and clears any retrieved contexts. + + Args: + self: The instance of the class calling this method. + + Returns: + None + + Raises: + None + """ self.conversation_history = ConversationHistory() - self.retrieved_contexts = [] \ No newline at end of file + self.retrieved_contexts = [] diff --git a/src/generation/summary_manager.py b/src/generation/summary_manager.py index 63202de..42a91ef 100644 --- a/src/generation/summary_manager.py +++ b/src/generation/summary_manager.py @@ -16,7 +16,22 @@ class SummaryManager: - """Manages document summaries and keyword extraction from text data.""" + """ + Manages document summaries, including generation, storage, and retrieval. + + Methods: + generate_document_summary: Generates a document summary and keywords. + _select_diverse_chunks: Selects a diverse subset of chunks. + _summarize_content_structure: Summarizes the content structure based on headers. + _format_content_samples: Formats content samples for display. + _parse_summary: Parses the summary and keywords from an Anthropic Message. + _extract_data_from_text: Extracts summary and keywords from text if JSON parsing fails. + load_summaries: Loads document summaries from a JSON file. + save_summaries: Saves the document summaries to a JSON file. + get_all_summaries: Returns a list of all document summaries. + clear_summaries: Clears the document summaries and removes the summaries file. + process_file: Processes the file and generates or loads its summary. + """ def __init__(self, model_name: str = MAIN_MODEL): weave.init(project_name=WEAVE_PROJECT_NAME) @@ -232,6 +247,19 @@ def clear_summaries(self): @base_error_handler def process_file(self, data: list[dict], file_name: str): + """ + Process the file and generate or load its summary. + + Args: + data (list[dict]): The list of data dictionaries to be processed. + file_name (str): The name of the file to process. 
+ + Returns: + None + + Raises: + None + """ if file_name in self.summaries: result = self.summaries[file_name] logger.info(f"Loading existing summary for {file_name}.") diff --git a/src/generation/tool_definitions.py b/src/generation/tool_definitions.py index a83ca4e..83f769d 100644 --- a/src/generation/tool_definitions.py +++ b/src/generation/tool_definitions.py @@ -3,12 +3,12 @@ class Tool: """ - Represent a tool with a name, description, and input schema. + Define a tool with a name, description, and input schema. Args: name (str): The name of the tool. description (str): A brief description of the tool. - input_schema (dict[str, Any]): A dictionary representing the input schema of the tool. + input_schema (dict[str, Any]): The input schema of the tool. """ def __init__(self, name: str, description: str, input_schema: dict[str, Any]): @@ -17,6 +17,13 @@ def __init__(self, name: str, description: str, input_schema: dict[str, Any]): self.input_schema = input_schema def to_dict(self) -> dict[str, Any]: + """ + Convert the object's attributes to a dictionary. + + Returns: + dict[str, Any]: A dictionary containing the object's attributes. + + """ return {"name": self.name, "description": self.description, "input_schema": self.input_schema} @@ -28,19 +35,50 @@ class ToolManager: None Attributes: - tools (dict[str, Any]): A dictionary storing tools with their name as the key. + tools (Dict[str, Any]): A dictionary to store tools by their names. """ def __init__(self): self.tools: dict[str, Any] = {} def add_tool(self, tool: Tool): + """ + Add a tool to the tools collection. + + Args: + tool (Tool): The tool instance to be added to the collection. + + Returns: + None + + Raises: + ValueError: If a tool with the same name already exists in the collection. + """ self.tools[tool.name] = tool def get_tool(self, name: str) -> Tool: + """ + Retrieve a tool by its name from the tools dictionary. + + Args: + name (str): The name of the tool to retrieve. + + Returns: + Tool: The tool associated with the given name. + + Raises: + KeyError: If no tool exists with the provided name. + """ return self.tools[name] def get_all_tools(self) -> list[dict[str, Any]]: + """ + Retrieve all tools as a list of dictionaries. + + Returns: + list[dict[str, Any]]: A list where each item is a dictionary representation of a tool. + + """ return [tool.to_dict() for tool in self.tools.values()] diff --git a/src/processing/chunking.py b/src/processing/chunking.py index fbb7248..7b15853 100644 --- a/src/processing/chunking.py +++ b/src/processing/chunking.py @@ -84,7 +84,19 @@ def __init__( @base_error_handler def load_data(self) -> dict[str, Any]: - """Loads markdown from JSON and prepares for chunking""" + """ + Load data from a JSON file and return as a dictionary. + + Args: + None + + Returns: + dict[str, Any]: The JSON content parsed as a dictionary. + + Raises: + FileNotFoundError: If the JSON file is not found. + json.JSONDecodeError: If the JSON file has invalid content. + """ input_filepath = os.path.join(RAW_DATA_DIR, self.input_filename) try: @@ -101,7 +113,18 @@ def load_data(self) -> dict[str, Any]: @base_error_handler def remove_images(self, content: str) -> str: - """Removes all types of images from the content.""" + """ + Remove various forms of image links and tags from a given content string. + + Args: + content (str): The text content possibly containing image links and tags. + + Returns: + str: The content with all image links and tags removed. 
+ + Raises: + Exception: Raised if there are any issues during the execution of the function. + """ # Remove HTML img tags (in case any slipped through from FireCrawl) content = re.sub(r"]+>", "", content) @@ -123,7 +146,19 @@ def remove_images(self, content: str) -> str: @base_error_handler def process_pages(self, json_input: dict[str, Any]) -> list[dict[str, Any]]: - """Iterates through each page in the loaded data""" + """ + Process pages from JSON input and generate data chunks. + + Args: + json_input (dict[str, Any]): The input JSON containing page data and metadata. + + Returns: + list[dict[str, Any]]: A list of processed data chunks. + + Raises: + KeyError: If the JSON input does not contain the required keys. + ValueError: If there is an issue with the page content processing. + """ all_chunks = [] for _index, page in enumerate(json_input["data"]): page_content = page["markdown"] @@ -151,7 +186,18 @@ def process_pages(self, json_input: dict[str, Any]) -> list[dict[str, Any]]: @base_error_handler def remove_boilerplate(self, content: str) -> str: - """Removes navigation and boilerplate content from markdown.""" + """ + Remove boilerplate text from the given content. + + Args: + content (str): The content from which to remove the boilerplate text. + + Returns: + str: The content with the boilerplate text removed and extra newlines cleaned up. + + Raises: + AssertionError: If content is not of type str. + """ # Use precompiled regex cleaned_content = self.boilerplate_regex.sub("", content) # Remove any extra newlines left after removing boilerplate @@ -160,7 +206,18 @@ def remove_boilerplate(self, content: str) -> str: @base_error_handler def clean_header_text(self, header_text: str) -> str: - """Cleans unwanted markdown elements and artifacts from header text.""" + """ + Clean header text by removing specific markdown and zero-width spaces. + + Args: + header_text (str): The text to be cleaned. + + Returns: + str: The cleaned header text. + + Raises: + ValueError: If `header_text` is not a string. + """ # Remove zero-width spaces cleaned_text = header_text.replace("\u200b", "") # Remove markdown links but keep the link text @@ -175,7 +232,19 @@ def clean_header_text(self, header_text: str) -> str: @base_error_handler def identify_sections(self, page_content: str, page_metadata: dict[str, Any]) -> list[dict[str, Any]]: - """Identifies sections in the page content based on headers and preserves markdown structures.""" + """ + Identify the sections and headers in the provided page content. + + Args: + page_content (str): The content of the page to be analyzed. + page_metadata (dict[str, Any]): Metadata of the page provided as a dictionary. + + Returns: + list[dict[str, Any]]: A list of sections, each represented as a dictionary with headers and content. + + Raises: + ValueError: If an unclosed code block is detected. + """ sections = [] in_code_block = False code_fence = "" @@ -248,6 +317,20 @@ def identify_sections(self, page_content: str, page_metadata: dict[str, Any]) -> @base_error_handler def create_chunks(self, sections: list[dict[str, Any]], page_metadata: dict[str, Any]) -> list[dict[str, Any]]: + """ + Create chunks from sections and adjust them according to page metadata. + + Args: + sections (list[dict[str, Any]]): List of sections, where each section is a dictionary with "content" and + "headers". + page_metadata (dict[str, Any]): Metadata related to the page, used to enrich chunk metadata. 
+ + Returns: + list[dict[str, Any]]: A list of adjusted and enriched chunks with ids, metadata, and data. + + Raises: + CustomException: If validation or adjustment fails during the chunk creation process. + """ page_chunks = [] for section in sections: section_chunks = self._split_section(section["content"], section["headers"]) @@ -283,6 +366,19 @@ def create_chunks(self, sections: list[dict[str, Any]], page_metadata: dict[str, @base_error_handler def _split_section(self, content: str, headers: dict[str, str]) -> list[dict[str, Any]]: # noqa: C901 + """ + Split the content into sections based on headers and code blocks. + + Args: + content (str): The content to be split. + headers (dict[str, str]): The headers associated with each content chunk. + + Returns: + list[dict[str, Any]]: A list of dictionaries, each containing 'headers' and 'content'. + + Raises: + ValidationError: If an unclosed code block is detected. + """ # TODO: refactor this method to reduce complexity # Current complexity is necessary for accurate content splitting chunks = [] @@ -386,7 +482,19 @@ def _split_section(self, content: str, headers: dict[str, str]) -> list[dict[str @base_error_handler def _split_code_block(self, code_block_content: str, code_fence: str) -> list[str]: - """Splits a code block into smaller chunks without breaking code syntax.""" + """ + Split a code block into smaller chunks based on token count. + + Args: + code_block_content (str): The content of the code block to be split. + code_fence (str): The code fence delimiter used to format the code block. + + Returns: + list[str]: A list of code block chunks. + + Raises: + None + """ lines = code_block_content.strip().split("\n") chunks = [] current_chunk_lines = [] @@ -420,7 +528,20 @@ def _split_code_block(self, code_block_content: str, code_fence: str) -> list[st @base_error_handler def _adjust_chunks(self, chunks: list[dict[str, Any]]) -> list[dict[str, Any]]: - """Adjust chunks to meet min and max token constraints by merging or splitting.""" + """ + Adjust chunks to be within the specified token limits. + + Adjusts the size of the given text chunks by merging small chunks and splitting large ones. + + Args: + chunks: A list of dictionaries, where each dictionary contains headers and content. + + Returns: + A list of dictionaries representing the adjusted chunks. + + Raises: + ValueError: If a chunk cannot be adjusted to meet the token requirements. + """ adjusted_chunks = [] i = 0 while i < len(chunks): @@ -480,7 +601,22 @@ def _adjust_chunks(self, chunks: list[dict[str, Any]]) -> list[dict[str, Any]]: @base_error_handler def _split_large_chunk(self, chunk: dict[str, Any]) -> list[dict[str, Any]]: - """Splits a chunk that exceeds 2x max_tokens into smaller chunks.""" + """ + Split a large text chunk into smaller chunks. + + Args: + chunk (dict[str, Any]): A dictionary containing 'content' and 'headers' keys. + 'content' is the text to split, and 'headers' are additional metadata. + + Returns: + list[dict[str, Any]]: A list of dictionaries where each dictionary contains a portion of the original + content + and a copy of the headers. + + Raises: + KeyError: If 'content' or 'headers' keys are not found in the chunk dictionary. + Any other exceptions raised by self._calculate_tokens method. 
+ """ content = chunk["content"] headers = chunk["headers"] lines = content.split("\n") @@ -506,6 +642,16 @@ def _split_large_chunk(self, chunk: dict[str, Any]) -> list[dict[str, Any]]: @base_error_handler def _merge_headers(self, headers1: dict[str, str], headers2: dict[str, str]) -> dict[str, str]: + """ + Merge two headers dictionaries by levels. + + Args: + headers1 (dict[str, str]): The first headers dictionary. + headers2 (dict[str, str]): The second headers dictionary. + + Returns: + dict[str, str]: The merged headers dictionary with levels "h1", "h2", and "h3". + """ merged = {} for level in ["h1", "h2", "h3"]: header1 = headers1.get(level, "").strip() @@ -522,6 +668,20 @@ def _merge_headers(self, headers1: dict[str, str], headers2: dict[str, str]) -> def _add_overlap( self, chunks: list[dict[str, Any]], min_overlap_tokens: int = 50, max_overlap_tokens: int = 100 ) -> None: + """ + Add overlap to chunks of text based on specified token limits. + + Args: + chunks (list[dict[str, Any]]): List of text chunks with metadata. + min_overlap_tokens (int): Minimum number of tokens for the overlap. + max_overlap_tokens (int): Maximum number of tokens for the overlap. + + Returns: + None. + + Raises: + ValidationError: If adding overlap exceeds the maximum allowed tokens. + """ for i in range(1, len(chunks)): prev_chunk = chunks[i - 1] curr_chunk = chunks[i] @@ -550,7 +710,15 @@ def _add_overlap( curr_chunk["metadata"]["token_count"] += additional_tokens def _split_long_line(self, line: str) -> list[str]: - """Splits a long line into smaller chunks not exceeding 2 * max_tokens.""" + """ + Split a long line of text into smaller chunks based on token limits. + + Args: + line (str): The line of text to be split. + + Returns: + list[str]: A list containing the smaller chunks of text. + """ tokens = self.tokenizer.encode(line) max_tokens_per_chunk = 2 * self.max_tokens chunks = [] @@ -561,13 +729,34 @@ def _split_long_line(self, line: str) -> list[str]: return chunks def _get_last_n_tokens(self, text: str, n: int) -> str: + """ + Get the last n tokens from the given text. + + Args: + text (str): The text to tokenize. + n (int): The number of tokens to retrieve from the end of the text. + + Returns: + str: The decoded string of the last n tokens. + + Raises: + ValueError: If n is greater than the number of tokens in the text. + """ tokens = self.tokenizer.encode(text) last_n_tokens = tokens[-n:] return self.tokenizer.decode(last_n_tokens) @base_error_handler def save_chunks(self, chunks: list[dict[str, Any]]): - """Saves chunks to output dir""" + """ + Save the given chunks to a JSON file. + + Args: + chunks (list of dict): A list of dictionaries containing chunk data. + + Raises: + Exception: If an error occurs while saving the chunks to the file. + """ input_name = os.path.splitext(self.input_filename)[0] # Remove the extension output_filename = f"{input_name}-chunked.json" output_filepath = os.path.join(self.output_dir, output_filename) @@ -577,18 +766,47 @@ def save_chunks(self, chunks: list[dict[str, Any]]): @base_error_handler def _generate_chunk_id(self) -> uuid.UUID: - """Generates chunk's uuidv4""" + """ + Generate a new UUID for chunk identification. + + Returns: + uuid.UUID: A new unique identifier for the chunk. + + """ return uuid.uuid4() @base_error_handler def _calculate_tokens(self, text: str) -> int: - """Calculates the number of tokens in a given text using tiktoken""" + """ + Calculate the number of tokens in a given text. + + Args: + text (str): The input text to be tokenized. 
+ + Returns: + int: The number of tokens in the input text. + + Raises: + TokenizationError: If there is an error during tokenization. + """ token_count = len(self.tokenizer.encode(text)) return token_count @base_error_handler def _create_metadata(self, page_metadata: dict[str, Any], token_count: int) -> dict[str, Any]: - """Creates metadata dictionary for a chunk""" + """ + Create metadata dictionary for a page. + + Args: + page_metadata (dict[str, Any]): Metadata extracted from a page. + token_count (int): Number of tokens in the page content. + + Returns: + dict[str, Any]: A dictionary containing the token count, source URL, and page title. + + Raises: + None + """ metadata = { "token_count": token_count, "source_url": page_metadata.get("sourceURL", ""), @@ -626,26 +844,100 @@ def __init__(self, min_chunk_size, max_tokens, output_dir, input_filename, save: self.duplicates_removed = 0 def increment_total_headings(self, level, heading_text): + """ + Add a heading text to the total_headings dictionary under the specified level. + + Args: + level (int): The level of the heading. + heading_text (str): The text of the heading to add. + + Returns: + None + + Raises: + KeyError: If the specified level does not exist in the total_headings dictionary. + """ self.total_headings[level].add(heading_text.strip()) def add_preserved_heading(self, level, heading_text): + """ + Add a heading to the preserved headings list at the specified level. + + Args: + level (int): The level at which the heading should be preserved. + heading_text (str): The text of the heading to preserve. + + Returns: + None + + Raises: + KeyError: If the specified level does not exist in headings_preserved. + """ self.headings_preserved[level].add(heading_text.strip()) def add_chunk(self, token_count): + """ + Add a chunk of tokens and update the tracking attributes. + + Args: + token_count (int): The number of tokens in the new chunk. + + Returns: + None + + Raises: + None + """ self.total_chunks += 1 self.total_tokens += token_count self.chunk_token_counts.append(token_count) def add_validation_error(self, error_message): + """ + Add a validation error message to the validation errors list. + + Args: + error_message (str): The error message to be added. + + Returns: + None + + Raises: + TypeError: If error_message is not a string. + """ self.validation_errors.append(error_message) def validate(self, chunks): + """ + Validates the given chunks by checking for duplicates and finding incorrect chunks. + + Args: + chunks: A list of data chunks to be validated. + + Returns: + None + + Raises: + ValidationError: If duplicates or incorrect chunks are found. + """ self.validate_duplicates(chunks) self.find_incorrect_chunks(chunks, save=self.save) self.log_summary() def validate_duplicates(self, chunks: list[dict[str, Any]]) -> None: - """Removes duplicate chunks and updates counts.""" + """ + Validate and remove duplicate chunks based on the text content. + + Args: + chunks (list[dict[str, Any]]): The list of chunks where each chunk is a dictionary + containing text data under the "data" key. + + Returns: + None + + Raises: + None + """ unique_chunks = {} cleaned_chunks = [] for chunk in chunks: @@ -664,7 +956,18 @@ def validate_duplicates(self, chunks: list[dict[str, Any]]) -> None: self.total_chunks = len(chunks) def log_summary(self): - """Logs a concise summary of the chunking process""" + """ + Log a summary of chunk creation, statistics, headers, validation errors, and incorrect chunks. 
+ + Args: + self: An instance of the class containing chunk and heading info, validation errors, etc. + + Returns: + None + + Raises: + None + """ # Total chunks logger.info(f"Total chunks created: {self.total_chunks}") @@ -712,7 +1015,19 @@ def log_summary(self): logger.info(incorrect_chunks_info) def find_incorrect_chunks(self, chunks: list[dict[str, Any]], save: bool = False) -> None: - """Finds chunks below min_chunk_size or above 2x max_tokens and saves to JSON.""" + """ + Identify chunks that are too small or too large and optionally save them to a file. + + Args: + chunks (list[dict[str, Any]]): List of chunk dictionaries containing metadata and data for each chunk. + save (bool, optional): If True, save the incorrect chunks to a file. Defaults to False. + + Returns: + None + + Raises: + None + """ incorrect = { "too_small": [ { diff --git a/src/utils/decorators.py b/src/utils/decorators.py index 5e5a751..274fd05 100644 --- a/src/utils/decorators.py +++ b/src/utils/decorators.py @@ -21,6 +21,19 @@ def base_error_handler(func: Callable) -> Callable: + """ + Return a decorator that wraps a function with error handling and logging. + + Args: + func (Callable): The function to be wrapped with error handling. + + Returns: + Callable: The wrapped function with error handling. + + Raises: + Exception: Re-raises the original exception after logging the error. + """ + @functools.wraps(func) def wrapper(*args, **kwargs) -> Any: # Access the logger @@ -36,6 +49,8 @@ def wrapper(*args, **kwargs) -> Any: def application_level_handler(func: Callable) -> Callable: + """Retrieve the application logger.""" + @functools.wraps(func) def wrapper(*args, **kwargs) -> Any: # Access the logger @@ -57,6 +72,28 @@ def wrapper(*args, **kwargs) -> Any: def anthropic_error_handler(func: Callable) -> Callable: + """ + Apply error handling for various exceptions encountered in Anthropic API calls. + + Args: + func (Callable): The function to be wrapped with error handling. + + Returns: + Callable: A wrapper function that includes error handling. + + Raises: + AuthenticationError: If authentication fails. + BadRequestError: If the request is invalid. + PermissionDeniedError: If permission is denied. + NotFoundError: If the resource is not found. + RateLimitError: If the rate limit is exceeded. + APIConnectionError: For API connection issues, including timeout errors. + InternalServerError: If there's an internal server error. + APIError: For unexpected API errors. + AnthropicError: For unexpected Anthropic-specific errors. + Exception: For any other unexpected errors. + """ + @functools.wraps(func) def wrapper(*args, **kwargs) -> Any: # Access the logger @@ -102,6 +139,16 @@ def wrapper(*args, **kwargs) -> Any: def performance_logger(func: Callable) -> Callable: + """ + Decorate a function to log its execution time. + + Args: + func (Callable): The function to be decorated. + + Returns: + Callable: The wrapped function with logging of execution time. + """ + @functools.wraps(func) def wrapper(*args, **kwargs) -> Any: # Access the logger diff --git a/src/utils/logger.py b/src/utils/logger.py index dd86faa..ea39bca 100644 --- a/src/utils/logger.py +++ b/src/utils/logger.py @@ -11,6 +11,19 @@ class ColoredFormatter(logging.Formatter): + """ + Enhance log messages with colors and emojis based on their severity levels. + + Args: + logging (module): A logging module instance for handling logs. + + Returns: + None + + Raises: + KeyError: If a log level is not found in COLORS or EMOJIS dictionaries. 
+ """ + COLORS = { logging.DEBUG: Fore.BLUE, logging.INFO: Fore.LIGHTCYAN_EX, @@ -28,6 +41,19 @@ class ColoredFormatter(logging.Formatter): } def format(self, record): + """ + Format the log record with emoji and color based on log level. + + Args: + record (LogRecord): The log record to format. + + Returns: + str: The formatted log message including color and emoji. + + Raises: + KeyError: If a log level key does not exist in COLORS or EMOJIS. + + """ # Get the original log message log_message = super().format(record) @@ -43,11 +69,17 @@ def format(self, record): def configure_logging(debug=False, log_file="app.log"): """ - Configures the logging system. + Configure the application's logging system. + + Args: + debug (bool): Whether to set the logging level to debug. Defaults to False. + log_file (str): The name of the file to log to. Defaults to "app.log". + + Returns: + None - Parameters: - - debug (bool): If True, sets the logging level to DEBUG. - - log_file (str): The name of the log file. + Raises: + Exception: Any exception that logging handlers or the file system might raise. """ log_level = logging.DEBUG if debug else logging.INFO app_logger = logging.getLogger("kollektiv") @@ -79,7 +111,16 @@ def configure_logging(debug=False, log_file="app.log"): def get_logger(): """ - Retrieves a logger with the 'kollektiv' prefix based on the caller's module. + Retrieve a logger named after the calling module. + + Args: + None + + Returns: + logging.Logger: A logger specifically named for the module calling the function. + + Raises: + None """ import inspect diff --git a/src/utils/output_formatter.py b/src/utils/output_formatter.py index 6b292d1..26dd36b 100644 --- a/src/utils/output_formatter.py +++ b/src/utils/output_formatter.py @@ -8,13 +8,47 @@ def print_assistant_stream(message: str, end: str = "\n", flush: bool = True): + """ + Print a message in the assistant's color stream. + + Args: + message (str): The message to be printed. + end (str, optional): The string appended after the last value, default is a newline. + flush (bool, optional): Whether to forcibly flush the stream, default is True. + + Returns: + None + + """ print(f"{ASSISTANT_COLOR}{message}{Style.RESET_ALL}", end=end, flush=flush) def print_welcome_message(message: str): + """ + Print the welcome message to the console. + + Args: + message (str): The welcome message to be printed. + + Returns: + None + + """ print(f"\n{ASSISTANT_COLOR}{message}{Style.RESET_ALL}") def user_input(): + """ + Prompt the user for input and return the input string. + + Args: + None + + Returns: + str: The input provided by the user. + + Raises: + None + """ user_input = input(f"{USER_COLOR}{Style.BRIGHT}You:{Style.RESET_ALL} ") return user_input diff --git a/src/vector_storage/vector_db.py b/src/vector_storage/vector_db.py index fd8da56..3355c04 100644 --- a/src/vector_storage/vector_db.py +++ b/src/vector_storage/vector_db.py @@ -21,10 +21,37 @@ class DocumentProcessor: + """ + Process and manage document data. + + Args: + filename (str): The name of the JSON file to load. + + Returns: + list[dict]: A list of dictionaries containing the JSON data. + + Raises: + FileNotFoundError: If the file cannot be found at the specified path. + json.JSONDecodeError: If the file contains invalid JSON. + """ + def __init__(self): self.processed_dir = PROCESSED_DATA_DIR def load_json(self, filename: str) -> list[dict]: + """ + Load and parse JSON data from a specified file. + + Args: + filename (str): Name of the file containing JSON data. 
+
+        Returns:
+            list[dict]: A list of dictionaries parsed from the JSON file.
+
+        Raises:
+            FileNotFoundError: If the specified file cannot be found.
+            json.JSONDecodeError: If the file contains invalid JSON.
+        """
         try:
             filepath = os.path.join(self.processed_dir, filename)
             with open(filepath) as f:
@@ -39,32 +66,112 @@ def load_json(self, filename: str) -> list[dict]:
 
 
 class VectorDBInterface(ABC):
+    """Define an interface for vector database operations."""
+
     @abstractmethod
     def prepare_documents(self, chunks: list[dict[str, Any]]) -> dict[str, list[str]]:
+        """
+        Prepare documents from given chunks.
+
+        Args:
+            chunks (list[dict[str, Any]]): A list of dictionaries where each dictionary contains various
+            data attributes.
+
+        Returns:
+            dict[str, list[str]]: A dictionary where each key is a document identifier and the value is a list of
+            processed text elements.
+
+        Raises:
+            ValueError: If the input chunks are not in the expected format.
+        """
        pass

    @abstractmethod
    def add_documents(self, processed_docs: dict[str, list[str]]) -> None:
+        """Add processed documents to the data store.
+
+        Args:
+            processed_docs (dict[str, list[str]]): A dictionary where the key is a document identifier and the value is
+            a list of processed text elements for that document.
+
+        Returns:
+            None
+
+        Raises:
+            NotImplementedError: If the method is not implemented by a subclass.
+        """
        pass

    @abstractmethod
    def query(self, user_query: str | list[str], n_results: int = 10) -> dict[str, Any]:
+        """
+        Perform a query and return the results.
+
+        Args:
+            user_query (str | list[str]): The query string(s) to search for.
+            n_results (int, optional): The number of results to return. Defaults to 10.
+
+        Returns:
+            dict[str, Any]: The query results in a dictionary format.
+
+        Raises:
+            NotImplementedError: If the method is not implemented in a subclass.
+        """
        pass

    @abstractmethod
    def reset_database(self) -> None:
+        """
+        Reset the database to its initial state.
+
+        Raises:
+            NotImplementedError: If the method is not implemented in a subclass.
+        """
        pass

    @abstractmethod
    def deduplicate_documents(self, search_results: dict[str, Any]) -> dict[str, Any]:
+        """
+        Handle the deduplication of search results.
+
+        Args:
+            search_results (dict[str, Any]): A dictionary containing search results where keys are document IDs
+            and values are document data.
+
+        Returns:
+            dict[str, Any]: A dictionary with duplicate documents removed, keyed by document ID.
+
+        Raises:
+            ValueError: If the input search results are not a dictionary.
+        """
        pass

    @abstractmethod
    def check_documents_exist(self, document_ids: list[str]) -> tuple[bool, list[str]]:
+        """
+        Check if the given documents exist in the database.
+
+        Args:
+            document_ids (list[str]): A list of document IDs to check.
+
+        Returns:
+            tuple[bool, list[str]]: A tuple containing a boolean indicating if all documents exist,
+            and a list of IDs that do not exist.
+
+        Raises:
+            ValueError: If the document_ids list is empty.
+        """
        pass


 class VectorDB(VectorDBInterface):
+    """
+    Store, query, and deduplicate document embeddings in a vector database.
+
+    Args:
+        embedding_function (str): The name of the embedding function to use. Defaults to "text-embedding-3-small".
+        openai_api_key (str): The OpenAI API key for authentication. Defaults to OPENAI_API_KEY.
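+
+    Example (illustrative sketch; the query text is an assumption):
+        >>> db = VectorDB()
+        >>> results = db.query("How do I get started?", n_results=5)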
+    """
+
     def __init__(
         self,
         embedding_function: str = "text-embedding-3-small",
@@ -94,6 +201,19 @@ def _init(self):
         )
 
     def prepare_documents(self, chunks: list[dict]) -> dict[str, list[str]]:
+        """
+        Prepare documents by extracting and combining headers and content.
+
+        Args:
+            chunks (list[dict]): A list of dictionaries where each dictionary contains the chunk data.
+
+        Returns:
+            dict[str, list[str]]: A dictionary with keys 'ids', 'documents', and 'metadatas' each containing a list.
+
+        Raises:
+            KeyError: If the required keys are missing from the dictionaries in chunks.
+            TypeError: If the input is not a list of dictionaries.
+        """
         ids = []
         documents = []
         metadatas = []  # used for filtering
@@ -124,6 +244,19 @@ def prepare_documents(self, chunks: list[dict]) -> dict[str, list[str]]:
 
     @base_error_handler
     def add_documents(self, json_data: list[dict], file_name: str) -> None:
+        """
+        Add documents from a given JSON list to the database, handling duplicates and generating summaries.
+
+        Args:
+            json_data (list[dict]): A list of dictionaries containing the document data.
+            file_name (str): The name of the file from which the documents are being added.
+
+        Returns:
+            None
+
+        Raises:
+            Exception: If there is an error during the document preparation or addition process.
+        """
         processed_docs = self.prepare_documents(json_data)
 
         ids = processed_docs["ids"]
@@ -154,7 +287,19 @@ def add_documents(self, json_data: list[dict], file_name: str) -> None:
 
     @base_error_handler
     def check_documents_exist(self, document_ids: list[str]) -> tuple[bool, list[str]]:
-        """Checks if chunks are already added to the database based on chunk ids"""
+        """
+        Check if documents exist.
+
+        Args:
+            document_ids (list[str]): The list of document IDs to check.
+
+        Returns:
+            tuple[bool, list[str]]: A tuple where the first element is a boolean indicating if all documents exist,
+            and the second element is a list of missing document IDs.
+
+        Raises:
+            Exception: If an error occurs while checking the document existence.
+        """
         try:
             # Get existing document ids from the db
             result = self.collection.get(ids=document_ids, include=[])
@@ -175,7 +320,17 @@ def check_documents_exist(self, document_ids: list[str]) -> tuple[bool, list[str
     @base_error_handler
     def query(self, user_query: str | list[str], n_results: int = 10):
         """
-        Handles both a single query and multiple queries
+        Query the collection to retrieve documents based on the user's query.
+
+        Args:
+            user_query (str | list[str]): A string or list of strings representing the user's query.
+            n_results (int, optional): The number of results to retrieve. Defaults to 10.
+
+        Returns:
+            dict[str, Any]: The search results returned by the collection.
+
+        Raises:
+            Exception: If an error occurs while querying the collection.
         """
         query_texts = [user_query] if isinstance(user_query, str) else user_query
         search_results = self.collection.query(
@@ -185,6 +340,18 @@ def query(self, user_query: str | list[str], n_results: int = 10):
 
     @base_error_handler
     def reset_database(self):
+        """
+        Reset the database by deleting and recreating the collection, and clearing summaries.
+
+        Args:
+            self: The instance of the class containing this method.
+
+        Returns:
+            None
+
+        Raises:
+            Exception: If there is an error while deleting or creating the collection, or clearing summaries.
+        """
         # Delete collection
         self.client.delete_collection(self.collection_name)
@@ -198,6 +365,15 @@ def reset_database(self):
 
         logger.info("Database reset successfully.")
") def process_results_to_print(self, search_results: dict[str, Any]): + """ + Process search results to a formatted string output. + + Args: + search_results (dict[str, Any]): The search results containing documents and distances. + + Returns: + list[str]: A list of formatted strings containing distances and corresponding documents. + """ documents = search_results["documents"][0] distances = search_results["distances"][0] @@ -207,6 +383,18 @@ def process_results_to_print(self, search_results: dict[str, Any]): return output def deduplicate_documents(self, search_results: dict[str, Any]) -> dict[str, Any]: + """ + Remove duplicate documents from search results based on unique chunk IDs. + + Args: + search_results (dict[str, Any]): A dictionary containing lists of documents, distances, and IDs. + + Returns: + dict[str, Any]: A dictionary of unique documents with their corresponding text and distance. + + Raises: + None. + """ documents = search_results["documents"][0] distances = search_results["distances"][0] ids = search_results["ids"][0] @@ -220,6 +408,14 @@ def deduplicate_documents(self, search_results: dict[str, Any]) -> dict[str, Any class Reranker: + """ + Initializes and manages a Cohere Client for document re-ranking. + + Args: + cohere_api_key (str): API key for the Cohere service. + model_name (str): Name of the model to use for re-ranking. Defaults to "rerank-english-v3.0". + """ + def __init__(self, cohere_api_key: str = COHERE_API_KEY, model_name: str = "rerank-english-v3.0"): self.cohere_api_key = cohere_api_key self.model_name = model_name @@ -237,7 +433,15 @@ def _init(self): def extract_documents_list(self, unique_documents: dict[str, Any]) -> list[str]: """ - Prepares the unique documents list for processing by Cohere. + Extract the 'text' field from each unique document. + + Args: + unique_documents (dict[str, Any]): A dictionary where each value is a document represented as a dictionary + with a 'text' field. + + Returns: + list[str]: A list containing the 'text' field from each document. + """ # extract the 'text' field from each unique document document_texts = [chunk["text"] for chunk in unique_documents.values()] @@ -245,9 +449,19 @@ def extract_documents_list(self, unique_documents: dict[str, Any]) -> list[str]: def rerank(self, query: str, documents: dict[str, Any], return_documents=True) -> RerankResponse: """ - Use Cohere rerank API to score and rank documents based on the query. - Excludes irrelevant documents. - :return: list of documents with relevance scores + Rerank a list of documents based on their relevance to a given query. + + Args: + query (str): The search query to rank the documents against. + documents (dict[str, Any]): A dictionary containing documents to be reranked. + return_documents (bool): A flag indicating whether to return the full documents. Defaults to True. + + Returns: + RerankResponse: The reranked list of documents and their relevance scores. + + Raises: + SomeSpecificException: If an error occurs during the reranking process. 
+        """
         # extract list of documents
         document_texts = self.extract_documents_list(documents)
@@ -258,16 +472,18 @@ def rerank(self, query: str, documents: dict[str, Any], return_documents=True) -
         )
 
         logger.debug(f"Received {len(response.results)} documents from Cohere.")
-
-        # filter irrelevant results
-        # logger.debug(f"Filtering out the results with less than {relevance_threshold} relevance " f"score")
-        # relevant_results = self.filter_irrelevant_results(response, relevance_threshold)
-        # logger.debug(f"{len(relevant_results)} documents remaining after re-ranking.")
-
         return response
 
 
 class ResultRetriever:
+    """
+    Initializes the ResultRetriever with a vector database and a reranker.
+
+    Args:
+        vector_db (VectorDB): The vector database used for querying documents.
+        reranker (Reranker): The reranker used for reranking documents.
+    """
+
     def __init__(self, vector_db: VectorDB, reranker: Reranker):
         self.db = vector_db
         self.reranker = reranker
@@ -275,9 +491,21 @@ def __init__(self, vector_db: VectorDB, reranker: Reranker):
     @base_error_handler
     @weave.op()
     def retrieve(self, user_query: str, combined_queries: list[str], top_n: int = None):
-        """Returns ranked documents based on the user query:
-        top_n: The number of most relevant documents or indices to return, defaults to the length of the documents"""
+        """
+        Retrieve and rank documents based on user query and combined queries.
 
+        Args:
+            user_query (str): The primary user query for retrieving documents.
+            combined_queries (list[str]): A list of queries to combine for document retrieval.
+            top_n (int, optional): The maximum number of top documents to return. Defaults to None.
+
+        Returns:
+            dict[str, Any]: The limited set of ranked, relevant documents.
+
+        Raises:
+            Exception: If querying the database or reranking the documents fails.
+        """
         start_time = time.time()  # Start timing
 
         # get expanded search results
@@ -304,7 +532,20 @@ def retrieve(self, user_query: str, combined_queries: list[str], top_n: int = No
     def filter_irrelevant_results(
         self, response: RerankResponse, relevance_threshold: float = 0.1
     ) -> dict[int, dict[str, int | float | str]]:
-        """Filters out irrelevant result from Cohere reranking"""
+        """
+        Filter out results below a certain relevance threshold.
+
+        Args:
+            response (RerankResponse): The response containing the reranked results.
+            relevance_threshold (float): The minimum relevance score required. Defaults to 0.1.
+
+        Returns:
+            dict[int, dict[str, int | float | str]]: A dictionary of relevant results with their index, text,
+            and relevance score.
+
+        Raises:
+            None
+        """
         relevant_results = {}
 
         for result in response.results:
@@ -322,8 +563,19 @@ def filter_irrelevant_results(
         return relevant_results
 
     def limit_results(self, ranked_documents: dict[str, Any], top_n: int = None) -> dict[str, Any]:
-        """Takes re-ranked documents and returns n results"""
+        """
+        Limit the number of results based on the given top_n parameter.
 
+        Args:
+            ranked_documents (dict[str, Any]): A dictionary of documents with relevance scores.
+            top_n (int, optional): The number of top results to return. Defaults to None.
+
+        Returns:
+            dict[str, Any]: The dictionary containing the top N ranked documents, or all documents if top_n is None.
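+
+        Example (illustrative sketch; `retriever` and `ranked_documents` are assumed to exist):
+            >>> top_docs = retriever.limit_results(ranked_documents, top_n=3)
+            >>> len(top_docs) <= 3
+            True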
+ """ if top_n is not None and top_n < len(ranked_documents): # Sort the items by relevance score in descending order sorted_items = sorted(ranked_documents.items(), key=lambda x: x[1]["relevance_score"], reverse=True) @@ -342,6 +594,19 @@ def limit_results(self, ranked_documents: dict[str, Any], top_n: int = None) -> def main(): + """ + Configure logging, reset the vector database, process JSON documents, and add them to the database. + + Args: + None + + Returns: + None + + Raises: + FileNotFoundError: If the specified JSON file does not exist. + ValueError: If the JSON file is malformed or contains invalid data. + """ configure_logging() vector_db = VectorDB() vector_db.reset_database() diff --git a/tests/test_app.py b/tests/test_app.py index 5357be1..6021026 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -3,6 +3,12 @@ def test_app_initialization(): + """ + Test the initialization of application components. + + Ensures that instances of VectorDB, ClaudeAssistant, Reranker, and ResultRetriever + are successfully created and are not None. + """ vector_db = VectorDB() claude_assistant = ClaudeAssistant(vector_db) reranker = Reranker() diff --git a/tests/test_chunking.py b/tests/test_chunking.py index aedf50f..2a4a1cc 100644 --- a/tests/test_chunking.py +++ b/tests/test_chunking.py @@ -2,6 +2,17 @@ def test_markdown_chunker_initialization(): + """ + Test the initialization of the MarkdownChunker class. + + Asserts: + - The chunker instance is not None. + - The max_tokens attribute is set to 1000. + - The soft_token_limit attribute is set to 800. + + Raises: + AssertionError: If any of the assertions fail. + """ chunker = MarkdownChunker(input_filename="test_input.json") assert chunker is not None assert chunker.max_tokens == 1000 diff --git a/tests/test_claude_assistant.py b/tests/test_claude_assistant.py index 7cb48b2..6c5aa89 100644 --- a/tests/test_claude_assistant.py +++ b/tests/test_claude_assistant.py @@ -9,18 +9,24 @@ @pytest.fixture def mock_vector_db(): - """Fixture to mock the vector database dependency.""" + """Create a mock object for VectorDB. + + Returns: + MagicMock: A mock object that mimics the behavior of VectorDB. + """ return MagicMock(spec=VectorDB) @pytest.fixture def claude_assistant(mock_vector_db): """ - Fixture to create an instance of ClaudeAssistant with mocked dependencies. + Set up a ClaudeAssistant instance with mocked dependencies. + + Args: + mock_vector_db: A mock of the VectorDB class used for testing. - Patches: - - anthropic.Anthropic: Mocked to prevent actual API calls. - - tiktoken.get_encoding: Mocked encoding function. + Returns: + ClaudeAssistant: An instance of ClaudeAssistant with dependencies mocked. """ with patch("anthropic.Anthropic") as mock_anthropic, patch("tiktoken.get_encoding") as mock_encoding: mock_client = Mock() @@ -40,7 +46,16 @@ def claude_assistant(mock_vector_db): def test_streaming_response(claude_assistant): """ - Test the stream_response method to ensure it correctly handles streaming responses. + Test the streaming response from the Claude assistant. + + Args: + claude_assistant: An instance of the ClaudeAssistant to be tested. + + Returns: + None + + Raises: + AssertionError: If the response does not match the expected output. 
""" # Create a MagicMock stream object with iterable mock messages mock_stream = MagicMock() @@ -81,7 +96,16 @@ def test_streaming_response(claude_assistant): def test_non_streaming_response(claude_assistant): """ - Test the not_stream_response method to ensure it correctly handles non-streaming responses. + Test the non-streaming response of the Claude Assistant. + + Args: + claude_assistant: An instance of the Claude Assistant containing the tested method. + + Returns: + None. The function performs assertions to validate the behavior. + + Raises: + AssertionError: If any of the assertions fail. """ # Create a mock response object mock_response = Mock( @@ -116,7 +140,16 @@ def test_non_streaming_response(claude_assistant): def test_conversation_history_handling(claude_assistant): """ - Test the conversation history management within the assistant. + Test the handling and retrieval of conversation history for the assistant. + + Args: + claude_assistant: An instance of the assistant with conversation history functionality. + + Returns: + None + + Raises: + AssertionError: If the conversation history does not meet the expected conditions. """ # Add messages to the conversation history claude_assistant.conversation_history.add_message("user", "Hello") diff --git a/tests/test_crawler.py b/tests/test_crawler.py index d6b94dc..a68559f 100644 --- a/tests/test_crawler.py +++ b/tests/test_crawler.py @@ -3,6 +3,18 @@ def test_fire_crawler_initialization(): + """ + Test the initialization of the FireCrawler class. + + Args: + None + + Returns: + None + + Raises: + AssertionError: If the assertions regarding the FireCrawler instance fail. + """ crawler = FireCrawler(FIRECRAWL_API_KEY) assert crawler is not None assert crawler.api_key == FIRECRAWL_API_KEY diff --git a/tests/test_vector_db.py b/tests/test_vector_db.py index 232113c..e2d90aa 100644 --- a/tests/test_vector_db.py +++ b/tests/test_vector_db.py @@ -2,11 +2,36 @@ def test_vector_db_initialization(): + """ + Test the initialization of the VectorDB class. + + Ensures that an instance of VectorDB is created and the collection_name + is set to the default value "local-collection". + + Returns: + bool: True if the tests pass, otherwise raises an assertion error. + + """ vector_db = VectorDB() assert vector_db is not None assert vector_db.collection_name == "local-collection" def test_document_processor_initialization(): + """ + Test the initialization of the DocumentProcessor class. + + Checks whether an instance of DocumentProcessor is successfully created + and is not None. + + Args: + None + + Returns: + None + + Raises: + AssertionError: If the DocumentProcessor instance is None. + """ processor = DocumentProcessor() assert processor is not None