Skip to content

Commit

Permalink
Add support for anthropic (#1819)
Browse files Browse the repository at this point in the history
  • Loading branch information
Dev-Khant authored Sep 6, 2024
1 parent 965f7a3 commit d32ae1a
Show file tree
Hide file tree
Showing 13 changed files with 82 additions and 84 deletions.
2 changes: 1 addition & 1 deletion docs/components/llms/models/anthropic.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ os.environ["ANTHROPIC_API_KEY"] = "your-api-key"

config = {
"llm": {
"provider": "litellm",
"provider": "anthropic",
"config": {
"model": "claude-3-opus-20240229",
"temperature": 0.1,
Expand Down
67 changes: 67 additions & 0 deletions mem0/llms/anthropic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import subprocess
import sys
import os
import json
from typing import Dict, List, Optional

try:
import anthropic
except ImportError:
raise ImportError("The 'anthropic' library is required. Please install it using 'pip install anthropic'.")

from mem0.llms.base import LLMBase
from mem0.configs.llms.base import BaseLlmConfig


class AnthropicLLM(LLMBase):
    """LLM backend that generates responses through the Anthropic Messages API."""

    def __init__(self, config: Optional[BaseLlmConfig] = None):
        """Initialize the Anthropic client.

        Args:
            config (BaseLlmConfig, optional): LLM configuration. When the model
                is unset, defaults to "claude-3-5-sonnet-20240620". The API key
                is taken from the config or the ANTHROPIC_API_KEY env var.
        """
        super().__init__(config)

        if not self.config.model:
            self.config.model = "claude-3-5-sonnet-20240620"

        # Explicit config key wins; otherwise fall back to the environment.
        api_key = self.config.api_key or os.getenv("ANTHROPIC_API_KEY")
        self.client = anthropic.Anthropic(api_key=api_key)

    def generate_response(
        self,
        messages: List[Dict[str, str]],
        response_format=None,
        tools: Optional[List[Dict]] = None,
        tool_choice: str = "auto",
    ):
        """
        Generate a response based on the given messages using Anthropic.

        Args:
            messages (list): List of message dicts containing 'role' and 'content'.
            response_format (str or object, optional): Accepted for interface
                parity with the other LLM backends but ignored — the Anthropic
                Messages API has no response_format parameter.
            tools (list, optional): List of tools that the model can call. Defaults to None.
            tool_choice (str, optional): Tool choice method ("auto", "any", ...).
                Defaults to "auto".

        Returns:
            str: The generated response text.
        """
        # Anthropic takes the system prompt as a dedicated top-level parameter,
        # not as a message with role "system" — split it out of the list.
        system_message = ""
        filtered_messages = []
        for message in messages:
            if message["role"] == "system":
                system_message = message["content"]
            else:
                filtered_messages.append(message)

        params = {
            "model": self.config.model,
            "messages": filtered_messages,
            "system": system_message,
            "temperature": self.config.temperature,
            "max_tokens": self.config.max_tokens,
            "top_p": self.config.top_p,
        }
        if tools:  # TODO: Remove tools if no issues found with new memory addition logic
            params["tools"] = tools
            # BUG FIX: the Anthropic API expects tool_choice as an object such
            # as {"type": "auto"}, not a bare string — a raw string makes every
            # tools-enabled request fail with an invalid_request_error.
            params["tool_choice"] = (
                {"type": tool_choice} if isinstance(tool_choice, str) else tool_choice
            )

        response = self.client.messages.create(**params)
        return response.content[0].text
11 changes: 1 addition & 10 deletions mem0/llms/aws_bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,7 @@
try:
import boto3
except ImportError:
user_input = input("The 'boto3' library is required. Install it now? [y/N]: ")
if user_input.lower() == 'y':
try:
subprocess.check_call([sys.executable, "-m", "pip", "install", "boto3"])
import boto3
except subprocess.CalledProcessError:
print("Failed to install 'boto3'. Please install it manually using 'pip install boto3'")
sys.exit(1)
else:
raise ImportError("The required 'boto3' library is not installed.")
raise ImportError("The 'boto3' library is required. Please install it using 'pip install boto3'.")

from mem0.llms.base import LLMBase
from mem0.configs.llms.base import BaseLlmConfig
Expand Down
2 changes: 1 addition & 1 deletion mem0/llms/azure_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def generate_response(
}
if response_format:
params["response_format"] = response_format
if tools:
if tools: # TODO: Remove tools if no issues found with new memory addition logic
params["tools"] = tools
params["tool_choice"] = tool_choice

Expand Down
12 changes: 1 addition & 11 deletions mem0/llms/groq.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,7 @@
try:
from groq import Groq
except ImportError:
user_input = input("The 'groq' library is required. Install it now? [y/N]: ")
if user_input.lower() == 'y':
try:
subprocess.check_call([sys.executable, "-m", "pip", "install", "groq"])
from groq import Groq
except subprocess.CalledProcessError:
print("Failed to install 'groq'. Please install it manually using 'pip install groq'.")
sys.exit(1)
else:
raise ImportError("The required 'groq' library is not installed.")
raise ImportError("The 'groq' library is required. Please install it using 'pip install groq'.")

from mem0.llms.base import LLMBase
from mem0.configs.llms.base import BaseLlmConfig
Expand All @@ -28,7 +19,6 @@ def __init__(self, config: Optional[BaseLlmConfig] = None):

if not self.config.model:
self.config.model = "llama3-70b-8192"
self.client = Groq()

api_key = self.config.api_key or os.getenv("GROQ_API_KEY")
self.client = Groq(api_key=api_key)
Expand Down
13 changes: 2 additions & 11 deletions mem0/llms/litellm.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,7 @@
try:
import litellm
except ImportError:
user_input = input("The 'litellm' library is required. Install it now? [y/N]: ")
if user_input.lower() == 'y':
try:
subprocess.check_call([sys.executable, "-m", "pip", "install", "litellm"])
import litellm
except subprocess.CalledProcessError:
print("Failed to install 'litellm'. Please install it manually using 'pip install litellm'.")
sys.exit(1)
else:
raise ImportError("The required 'litellm' library is not installed.")
raise ImportError("The 'litellm' library is required. Please install it using 'pip install litellm'.")

from mem0.llms.base import LLMBase
from mem0.configs.llms.base import BaseLlmConfig
Expand Down Expand Up @@ -91,7 +82,7 @@ def generate_response(
}
if response_format:
params["response_format"] = response_format
if tools:
if tools: # TODO: Remove tools if no issues found with new memory addition logic
params["tools"] = tools
params["tool_choice"] = tool_choice

Expand Down
12 changes: 1 addition & 11 deletions mem0/llms/ollama.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,7 @@
try:
from ollama import Client
except ImportError:
user_input = input("The 'ollama' library is required. Install it now? [y/N]: ")
if user_input.lower() == 'y':
try:
subprocess.check_call([sys.executable, "-m", "pip", "install", "ollama"])
from ollama import Client
except subprocess.CalledProcessError:
print("Failed to install 'ollama'. Please install it manually using 'pip install ollama'.")
sys.exit(1)
else:
print("The required 'ollama' library is not installed.")
sys.exit(1)
raise ImportError("The 'ollama' library is required. Please install it using 'pip install ollama'.")

from mem0.llms.base import LLMBase
from mem0.configs.llms.base import BaseLlmConfig
Expand Down
2 changes: 1 addition & 1 deletion mem0/llms/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def generate_response(

if response_format:
params["response_format"] = response_format
if tools:
if tools: # TODO: Remove tools if no issues found with new memory addition logic
params["tools"] = tools
params["tool_choice"] = tool_choice

Expand Down
15 changes: 2 additions & 13 deletions mem0/llms/together.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,7 @@
try:
from together import Together
except ImportError:
user_input = input("The 'together' library is required. Install it now? [y/N]: ")
if user_input.lower() == 'y':
try:
subprocess.check_call([sys.executable, "-m", "pip", "install", "together"])
from together import Together
except subprocess.CalledProcessError:
print("Failed to install 'together'. Please install it manually using 'pip install together'.")
sys.exit(1)
else:
print("The required 'together' library is not installed.")
sys.exit(1)
raise ImportError("The 'together' library is required. Please install it using 'pip install together'.")

from mem0.llms.base import LLMBase
from mem0.configs.llms.base import BaseLlmConfig
Expand All @@ -29,7 +19,6 @@ def __init__(self, config: Optional[BaseLlmConfig] = None):

if not self.config.model:
self.config.model = "mistralai/Mixtral-8x7B-Instruct-v0.1"
self.client = Together()

api_key = self.config.api_key or os.getenv("TOGETHER_API_KEY")
self.client = Together(api_key=api_key)
Expand Down Expand Up @@ -92,7 +81,7 @@ def generate_response(
}
if response_format:
params["response_format"] = response_format
if tools:
if tools: # TODO: Remove tools if no issues found with new memory addition logic
params["tools"] = tools
params["tool_choice"] = tool_choice

Expand Down
2 changes: 2 additions & 0 deletions mem0/llms/utils/tools.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# TODO: Remove these tools if no issues are found for new memory addition logic

ADD_MEMORY_TOOL = {
"type": "function",
"function": {
Expand Down
1 change: 1 addition & 0 deletions mem0/utils/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class LlmFactory:
"litellm": "mem0.llms.litellm.LiteLLM",
"azure_openai": "mem0.llms.azure_openai.AzureOpenAILLM",
"openai_structured": "mem0.llms.openai_structured.OpenAIStructuredLLM",
"anthropic": "mem0.llms.anthropic.AnthropicLLM"
}

@classmethod
Expand Down
13 changes: 1 addition & 12 deletions mem0/vector_stores/chroma.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,7 @@
import chromadb
from chromadb.config import Settings
except ImportError:
user_input = input("The 'chromadb' library is required. Install it now? [y/N]: ")
if user_input.lower() == 'y':
try:
subprocess.check_call([sys.executable, "-m", "pip", "install", "chromadb"])
import chromadb
from chromadb.config import Settings
except subprocess.CalledProcessError:
print("Failed to install 'chromadb'. Please install it manually using 'pip install chromadb'.")
sys.exit(1)
else:
print("The required 'chromadb' library is not installed.")
sys.exit(1)
raise ImportError("The 'chromadb' library is required. Please install it using 'pip install chromadb'.")

from mem0.vector_stores.base import VectorStoreBase

Expand Down
14 changes: 1 addition & 13 deletions mem0/vector_stores/pgvector.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,7 @@
import psycopg2
from psycopg2.extras import execute_values
except ImportError:
user_input = input("The 'psycopg2' library is required. Install it now? [y/N]: ")
if user_input.lower() == 'y':
try:
subprocess.check_call([sys.executable, "-m", "pip", "install", "psycopg2"])
import psycopg2
from psycopg2.extras import execute_values
except subprocess.CalledProcessError:
print("Failed to install 'psycopg2'. Please install it manually using 'pip install psycopg2'.")
sys.exit(1)
else:
print("The required 'psycopg2' library is not installed.")
sys.exit(1)

raise ImportError("The 'psycopg2' library is required. Please install it using 'pip install psycopg2'.")

from mem0.vector_stores.base import VectorStoreBase

Expand Down

0 comments on commit d32ae1a

Please sign in to comment.