From d32ae1a0b17744599c6e54725995a6eb5d62a85f Mon Sep 17 00:00:00 2001 From: Dev Khant Date: Sat, 7 Sep 2024 02:12:22 +0530 Subject: [PATCH] Add support for anthropic (#1819) --- docs/components/llms/models/anthropic.mdx | 2 +- mem0/llms/anthropic.py | 67 +++++++++++++++++++++++ mem0/llms/aws_bedrock.py | 11 +--- mem0/llms/azure_openai.py | 2 +- mem0/llms/groq.py | 12 +--- mem0/llms/litellm.py | 13 +---- mem0/llms/ollama.py | 12 +--- mem0/llms/openai.py | 2 +- mem0/llms/together.py | 15 +---- mem0/llms/utils/tools.py | 2 + mem0/utils/factory.py | 1 + mem0/vector_stores/chroma.py | 13 +---- mem0/vector_stores/pgvector.py | 14 +---- 13 files changed, 82 insertions(+), 84 deletions(-) create mode 100644 mem0/llms/anthropic.py diff --git a/docs/components/llms/models/anthropic.mdx b/docs/components/llms/models/anthropic.mdx index 916c88d2dd..bb5f8c2d16 100644 --- a/docs/components/llms/models/anthropic.mdx +++ b/docs/components/llms/models/anthropic.mdx @@ -11,7 +11,7 @@ os.environ["ANTHROPIC_API_KEY"] = "your-api-key" config = { "llm": { - "provider": "litellm", + "provider": "anthropic", "config": { "model": "claude-3-opus-20240229", "temperature": 0.1, diff --git a/mem0/llms/anthropic.py b/mem0/llms/anthropic.py new file mode 100644 index 0000000000..5e32d81c96 --- /dev/null +++ b/mem0/llms/anthropic.py @@ -0,0 +1,67 @@ +import subprocess +import sys +import os +import json +from typing import Dict, List, Optional + +try: + import anthropic +except ImportError: + raise ImportError("The 'anthropic' library is required. Please install it using 'pip install anthropic'.") +from mem0.llms.base import LLMBase +from mem0.configs.llms.base import BaseLlmConfig + + +class AnthropicLLM(LLMBase): + def __init__(self, config: Optional[BaseLlmConfig] = None): + super().__init__(config) + + if not self.config.model: + self.config.model = "claude-3-5-sonnet-20240620" + + api_key = self.config.api_key or os.getenv("ANTHROPIC_API_KEY") + self.client = anthropic.Anthropic(api_key=api_key) + + def generate_response( + self, + messages: List[Dict[str, str]], + response_format=None, + tools: Optional[List[Dict]] = None, + tool_choice: str = "auto", + ): + """ + Generate a response based on the given messages using Anthropic. + + Args: + messages (list): List of message dicts containing 'role' and 'content'. + response_format (str or object, optional): Format of the response. Defaults to "text". + tools (list, optional): List of tools that the model can call. Defaults to None. + tool_choice (str, optional): Tool choice method. Defaults to "auto". + + Returns: + str: The generated response.
+ """ + # Separate system message from other messages + system_message = "" + filtered_messages = [] + for message in messages: + if message['role'] == 'system': + system_message = message['content'] + else: + filtered_messages.append(message) + + params = { + "model": self.config.model, + "messages": filtered_messages, + "system": system_message, + "temperature": self.config.temperature, + "max_tokens": self.config.max_tokens, + "top_p": self.config.top_p, + } + if tools: # TODO: Remove tools if no issues found with new memory addition logic + params["tools"] = tools + params["tool_choice"] = tool_choice + + response = self.client.messages.create(**params) + return response.content[0].text diff --git a/mem0/llms/aws_bedrock.py b/mem0/llms/aws_bedrock.py index e2c1ea0f66..b5699a9efc 100644 --- a/mem0/llms/aws_bedrock.py +++ b/mem0/llms/aws_bedrock.py @@ -7,16 +7,7 @@ try: import boto3 except ImportError: - user_input = input("The 'boto3' library is required. Install it now? [y/N]: ") - if user_input.lower() == 'y': - try: - subprocess.check_call([sys.executable, "-m", "pip", "install", "boto3"]) - import boto3 - except subprocess.CalledProcessError: - print("Failed to install 'boto3'. Please install it manually using 'pip install boto3'") - sys.exit(1) - else: - raise ImportError("The required 'boto3' library is not installed.") + raise ImportError("The 'boto3' library is required. Please install it using 'pip install boto3'.") from mem0.llms.base import LLMBase from mem0.configs.llms.base import BaseLlmConfig diff --git a/mem0/llms/azure_openai.py b/mem0/llms/azure_openai.py index 520631ab14..d18569a8e7 100644 --- a/mem0/llms/azure_openai.py +++ b/mem0/llms/azure_openai.py @@ -87,7 +87,7 @@ def generate_response( } if response_format: params["response_format"] = response_format - if tools: + if tools: # TODO: Remove tools if no issues found with new memory addition logic params["tools"] = tools params["tool_choice"] = tool_choice diff --git a/mem0/llms/groq.py b/mem0/llms/groq.py index 0163f2c161..c67436c081 100644 --- a/mem0/llms/groq.py +++ b/mem0/llms/groq.py @@ -7,16 +7,7 @@ try: from groq import Groq except ImportError: - user_input = input("The 'groq' library is required. Install it now? [y/N]: ") - if user_input.lower() == 'y': - try: - subprocess.check_call([sys.executable, "-m", "pip", "install", "groq"]) - from groq import Groq - except subprocess.CalledProcessError: - print("Failed to install 'groq'. Please install it manually using 'pip install groq'.") - sys.exit(1) - else: - raise ImportError("The required 'groq' library is not installed.") + raise ImportError("The 'groq' library is required. Please install it using 'pip install groq'.") from mem0.llms.base import LLMBase from mem0.configs.llms.base import BaseLlmConfig @@ -28,7 +19,6 @@ def __init__(self, config: Optional[BaseLlmConfig] = None): if not self.config.model: self.config.model = "llama3-70b-8192" - self.client = Groq() api_key = self.config.api_key or os.getenv("GROQ_API_KEY") self.client = Groq(api_key=api_key) diff --git a/mem0/llms/litellm.py b/mem0/llms/litellm.py index 9d910370dc..1a6fb7c0c0 100644 --- a/mem0/llms/litellm.py +++ b/mem0/llms/litellm.py @@ -6,16 +6,7 @@ try: import litellm except ImportError: - user_input = input("The 'litellm' library is required. Install it now? [y/N]: ") - if user_input.lower() == 'y': - try: - subprocess.check_call([sys.executable, "-m", "pip", "install", "litellm"]) - import litellm - except subprocess.CalledProcessError: - print("Failed to install 'litellm'. Please install it manually using 'pip install litellm'.") - sys.exit(1) - else: - raise ImportError("The required 'litellm' library is not installed.") + raise ImportError("The 'litellm' library is required. Please install it using 'pip install litellm'.") from mem0.llms.base import LLMBase from mem0.configs.llms.base import BaseLlmConfig @@ -91,7 +82,7 @@ def generate_response( } if response_format: params["response_format"] = response_format - if tools: + if tools: # TODO: Remove tools if no issues found with new memory addition logic params["tools"] = tools params["tool_choice"] = tool_choice diff --git a/mem0/llms/ollama.py b/mem0/llms/ollama.py index e86cac5ea0..e7acbdd7b8 100644 --- a/mem0/llms/ollama.py +++ b/mem0/llms/ollama.py @@ -5,17 +5,7 @@ try: from ollama import Client except ImportError: - user_input = input("The 'ollama' library is required. Install it now? [y/N]: ") - if user_input.lower() == 'y': - try: - subprocess.check_call([sys.executable, "-m", "pip", "install", "ollama"]) - from ollama import Client - except subprocess.CalledProcessError: - print("Failed to install 'ollama'. Please install it manually using 'pip install ollama'.") - sys.exit(1) - else: - print("The required 'ollama' library is not installed.") - sys.exit(1) + raise ImportError("The 'ollama' library is required. Please install it using 'pip install ollama'.") from mem0.llms.base import LLMBase from mem0.configs.llms.base import BaseLlmConfig diff --git a/mem0/llms/openai.py b/mem0/llms/openai.py index 4900510678..68a3634531 100644 --- a/mem0/llms/openai.py +++ b/mem0/llms/openai.py @@ -100,7 +100,7 @@ def generate_response( if response_format: params["response_format"] = response_format - if tools: + if tools: # TODO: Remove tools if no issues found with new memory addition logic params["tools"] = tools params["tool_choice"] = tool_choice diff --git a/mem0/llms/together.py b/mem0/llms/together.py index fa477217d3..816ed7a99a 100644 --- a/mem0/llms/together.py +++ b/mem0/llms/together.py @@ -7,17 +7,7 @@ try: from together import Together except ImportError: - user_input = input("The 'together' library is required. Install it now? [y/N]: ") - if user_input.lower() == 'y': - try: - subprocess.check_call([sys.executable, "-m", "pip", "install", "together"]) - from together import Together - except subprocess.CalledProcessError: - print("Failed to install 'together'. Please install it manually using 'pip install together'.") - sys.exit(1) - else: - print("The required 'together' library is not installed.") - sys.exit(1) + raise ImportError("The 'together' library is required. Please install it using 'pip install together'.") from mem0.llms.base import LLMBase from mem0.configs.llms.base import BaseLlmConfig @@ -29,7 +19,6 @@ def __init__(self, config: Optional[BaseLlmConfig] = None): if not self.config.model: self.config.model = "mistralai/Mixtral-8x7B-Instruct-v0.1" - self.client = Together() api_key = self.config.api_key or os.getenv("TOGETHER_API_KEY") self.client = Together(api_key=api_key) @@ -92,7 +81,7 @@ def generate_response( } if response_format: params["response_format"] = response_format - if tools: + if tools: # TODO: Remove tools if no issues found with new memory addition logic params["tools"] = tools params["tool_choice"] = tool_choice diff --git a/mem0/llms/utils/tools.py b/mem0/llms/utils/tools.py index 50031e361f..fb4ff4a2a5 100644 --- a/mem0/llms/utils/tools.py +++ b/mem0/llms/utils/tools.py @@ -1,3 +1,5 @@ +# TODO: Remove these tools if no issues are found for new memory addition logic + ADD_MEMORY_TOOL = { "type": "function", "function": { diff --git a/mem0/utils/factory.py b/mem0/utils/factory.py index bf012b420c..c672cf1d6e 100644 --- a/mem0/utils/factory.py +++ b/mem0/utils/factory.py @@ -20,6 +20,7 @@ class LlmFactory: "litellm": "mem0.llms.litellm.LiteLLM", "azure_openai": "mem0.llms.azure_openai.AzureOpenAILLM", "openai_structured": "mem0.llms.openai_structured.OpenAIStructuredLLM", + "anthropic": "mem0.llms.anthropic.AnthropicLLM" } @classmethod diff --git a/mem0/vector_stores/chroma.py b/mem0/vector_stores/chroma.py index 4904aacf59..21fc85c686 100644 --- a/mem0/vector_stores/chroma.py +++ b/mem0/vector_stores/chroma.py @@ -9,18 +9,7 @@ import chromadb from chromadb.config import Settings except ImportError: - user_input = input("The 'chromadb' library is required. Install it now? [y/N]: ") - if user_input.lower() == 'y': - try: - subprocess.check_call([sys.executable, "-m", "pip", "install", "chromadb"]) - import chromadb - from chromadb.config import Settings - except subprocess.CalledProcessError: - print("Failed to install 'chromadb'. Please install it manually using 'pip install chromadb'.") - sys.exit(1) - else: - print("The required 'chromadb' library is not installed.") - sys.exit(1) + raise ImportError("The 'chromadb' library is required. Please install it using 'pip install chromadb'.") from mem0.vector_stores.base import VectorStoreBase diff --git a/mem0/vector_stores/pgvector.py b/mem0/vector_stores/pgvector.py index 7f8c2159d8..dae1e2d189 100644 --- a/mem0/vector_stores/pgvector.py +++ b/mem0/vector_stores/pgvector.py @@ -9,19 +9,7 @@ import psycopg2 from psycopg2.extras import execute_values except ImportError: - user_input = input("The 'psycopg2' library is required. Install it now? [y/N]: ") - if user_input.lower() == 'y': - try: - subprocess.check_call([sys.executable, "-m", "pip", "install", "psycopg2"]) - import psycopg2 - from psycopg2.extras import execute_values - except subprocess.CalledProcessError: - print("Failed to install 'psycopg2'. Please install it manually using 'pip install psycopg2'.") - sys.exit(1) - else: - print("The required 'psycopg2' library is not installed.") - sys.exit(1) - + raise ImportError("The 'psycopg2' library is required. Please install it using 'pip install psycopg2'.") from mem0.vector_stores.base import VectorStoreBase