-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #3 from shubham-attri/v1
pr
- Loading branch information
Showing
570 changed files
with
1,234 additions
and
276 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file modified
BIN
+0 Bytes
(100%)
backend/app/agents/__pycache__/orchestrator.cpython-312.pyc
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
from typing import Dict, Any | ||
from langchain_core.messages import BaseMessage | ||
from langchain_anthropic import ChatAnthropic | ||
from app.core.config import settings | ||
import logging | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
class BaseAgent:
    """Base agent class with common functionality shared by all agents.

    Wraps a ChatAnthropic LLM configured from application settings and
    exposes a uniform async ``process`` entry point that returns a
    response envelope dict.
    """

    def __init__(self, name: str, description: str):
        """Create an agent.

        Args:
            name: Human-readable agent name, used in logs and responses.
            description: Short description of the agent's responsibility.
        """
        self.name = name
        self.description = description
        # Credentials and model name come from central settings so every
        # agent subclass is configured consistently.
        self.llm = ChatAnthropic(
            anthropic_api_key=settings.ANTHROPIC_API_KEY,
            model_name=settings.ANTHROPIC_MODEL
        )
        # Lazy %-style args avoid string formatting when INFO is disabled.
        logger.info("Initialized %s agent", name)

    async def process(self, messages: list[BaseMessage], **kwargs) -> Dict[str, Any]:
        """Run the LLM over ``messages`` and return a response envelope.

        Args:
            messages: Conversation history to send to the model.
            **kwargs: Arbitrary metadata echoed back in the result.

        Returns:
            Dict with ``output`` (model response content), ``agent``
            (this agent's name) and ``metadata`` (the keyword arguments).

        Raises:
            Exception: Any error from the LLM call is logged (with
                traceback) and re-raised unchanged.
        """
        try:
            response = await self.llm.ainvoke(messages)
            return {
                "output": response.content,
                "agent": self.name,
                "metadata": kwargs
            }
        except Exception:
            # logger.exception records the traceback alongside the message.
            logger.exception("Error in %s agent", self.name)
            raise
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
from typing import Dict, Any | ||
from langchain_core.messages import BaseMessage | ||
from langchain_anthropic import ChatAnthropic | ||
from app.core.config import settings | ||
import logging | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
class BaseAgent:
    """Shared foundation for agents: holds identity and a configured LLM client."""

    def __init__(self, name: str, description: str):
        """Store the agent's identity and build its Anthropic chat client."""
        self.name = name
        self.description = description
        # Client configuration is drawn entirely from application settings.
        client = ChatAnthropic(
            anthropic_api_key=settings.ANTHROPIC_API_KEY,
            model_name=settings.ANTHROPIC_MODEL
        )
        self.llm = client
        logger.info(f"Initialized {name} agent")

    async def process(self, messages: list[BaseMessage], **kwargs) -> Dict[str, Any]:
        """Invoke the LLM on the message list; kwargs are echoed as metadata."""
        try:
            reply = await self.llm.ainvoke(messages)
            payload = {
                "output": reply.content,
                "agent": self.name,
                "metadata": kwargs,
            }
            return payload
        except Exception as err:
            logger.error(f"Error in {self.name} agent: {str(err)}")
            raise
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
from typing import Optional | ||
from langchain_core.messages import SystemMessage, HumanMessage | ||
from .base import BaseAgent | ||
|
||
class CaseAgent(BaseAgent):
    """Agent specialised in case-centric queries and document handling."""

    def __init__(self):
        """Configure the base agent with the fixed case-agent identity."""
        super().__init__(
            name="Case Agent",
            description="Handles case-specific analysis and document management"
        )

    async def process_query(self, query: str, case_id: Optional[str] = None) -> dict:
        """Answer a case-specific query, tagging the result with its case id."""
        system_prompt = SystemMessage(content="""You are a case analysis assistant.
            Focus on case-specific details and maintain context across interactions.""")
        convo = [system_prompt, HumanMessage(content=query)]
        return await self.process(convo, mode="case", case_id=case_id)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
from typing import Optional | ||
from langchain_core.messages import SystemMessage, HumanMessage | ||
from .base import BaseAgent | ||
|
||
class CaseAgent(BaseAgent):
    """Case-mode agent: analyses queries in the context of a specific case."""

    def __init__(self):
        """Initialise with the case agent's name and responsibility."""
        super().__init__(
            name="Case Agent",
            description="Handles case-specific analysis and document management"
        )

    async def process_query(self, query: str, case_id: Optional[str] = None) -> dict:
        """Run ``query`` through the base processing pipeline in case mode."""
        prompt_messages = []
        prompt_messages.append(SystemMessage(content="""You are a case analysis assistant.
            Focus on case-specific details and maintain context across interactions."""))
        prompt_messages.append(HumanMessage(content=query))
        result = await self.process(prompt_messages, mode="case", case_id=case_id)
        return result
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
from typing import Dict, Any, List, Literal, TypedDict | ||
from langgraph.graph import END, START, StateGraph | ||
from langgraph.prebuilt import ToolNode | ||
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage | ||
from langchain_anthropic import ChatAnthropic | ||
from langfuse.callback import CallbackHandler | ||
from .tools import tools | ||
from app.core.config import settings | ||
import logging | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
class MessagesState(TypedDict):
    """State schema carried through the orchestrator's LangGraph graph."""
    # Conversation history; the last entry drives should_continue routing.
    messages: List[BaseMessage]
    # Routing key: "research" selects the research model in call_model,
    # anything else falls through to the case model.
    mode: str
    # Initialised to "" by process_query; not read within this module.
    current_agent: str
    # Free-form metadata (process_query's **kwargs) carried alongside.
    metadata: Dict[str, Any]
|
||
class AgentOrchestrator:
    """Orchestrates multiple agents using LangGraph.

    Builds a two-node graph (LLM agent <-> tools) and routes each query
    through it, selecting a research- or case-bound model by ``mode``.
    """

    def __init__(self):
        self.tool_node = ToolNode(tools)
        # Langfuse callback handler used to trace every graph run.
        self.langfuse = CallbackHandler(
            public_key=settings.LANGFUSE_PUBLIC_KEY,
            secret_key=settings.LANGFUSE_SECRET_KEY,
            host=settings.LANGFUSE_HOST
        )
        # Initialize LLM models. Both currently share credentials and model
        # name but are kept as separate attributes so they can diverge in
        # configuration later.
        self.research_model = ChatAnthropic(
            anthropic_api_key=settings.ANTHROPIC_API_KEY,
            model_name=settings.ANTHROPIC_MODEL
        ).bind_tools(tools)
        self.case_model = ChatAnthropic(
            anthropic_api_key=settings.ANTHROPIC_API_KEY,
            model_name=settings.ANTHROPIC_MODEL
        ).bind_tools(tools)

        self.graph = self._build_graph()
        logger.info("Initialized Agent Orchestrator")

    def should_continue(self, state: MessagesState) -> str:
        """Route after an agent turn.

        Returns "tools" when the last AI message requested tool calls,
        otherwise END to stop the graph.

        NOTE: the original annotation ``Literal["tools", END]`` was invalid
        typing — END is a runtime constant, not a literal — so plain ``str``
        is used; runtime behavior is unchanged.
        """
        last_message = state["messages"][-1]
        if isinstance(last_message, AIMessage) and last_message.tool_calls:
            return "tools"
        return END

    def call_model(self, state: MessagesState) -> Dict:
        """Invoke the mode-appropriate model on the conversation so far.

        Returns:
            Partial state update ``{"messages": [response]}``.

        Raises:
            Exception: Logged (with traceback) and re-raised.
        """
        try:
            # "research" selects the research model; every other mode value
            # falls through to the case model.
            mode = state["mode"]
            model = self.research_model if mode == "research" else self.case_model
            response = model.invoke(state["messages"])
            return {"messages": [response]}
        except Exception:
            logger.exception("Error in model call")
            raise

    def _build_graph(self) -> StateGraph:
        """Build and compile the agent/tools interaction graph.

        NOTE(review): ``compile()`` returns a compiled graph object rather
        than a StateGraph; the original annotation is kept for compatibility.
        """
        graph = StateGraph(MessagesState)

        graph.add_node("agent", self.call_model)
        graph.add_node("tools", self.tool_node)

        # Every run starts at the agent node.
        graph.add_edge(START, "agent")

        # After each agent turn, either execute requested tools or finish.
        graph.add_conditional_edges(
            "agent",
            self.should_continue,
            {
                "tools": "tools",
                END: END
            }
        )

        # Tool output always flows back to the agent for another turn.
        graph.add_edge("tools", "agent")

        return graph.compile()

    async def process_query(
        self,
        query: str,
        mode: str = "research",
        thread_id: str | None = None,
        **kwargs
    ) -> Dict[str, Any]:
        """Process ``query`` through the compiled graph.

        Args:
            query: User question to answer.
            mode: "research" routes to the research model; any other value
                uses the case model.
            thread_id: Optional conversation thread id forwarded to the
                graph's configurable options. (Annotation fixed from the
                original ``str = None``.)
            **kwargs: Extra metadata stored in the graph state.

        Returns:
            The final graph state returned by ``ainvoke``.

        Raises:
            Exception: Logged (with traceback) and re-raised.
        """
        try:
            state = {
                "messages": [HumanMessage(content=query)],
                "mode": mode,
                "current_agent": "",
                "metadata": kwargs
            }

            # Trace the run through Langfuse; thread_id (when given) lets
            # the backend correlate turns of the same conversation.
            config = {
                "callbacks": [self.langfuse],
                "configurable": {"thread_id": thread_id} if thread_id else {}
            }

            return await self.graph.ainvoke(state, config=config)

        except Exception:
            logger.exception("Error in orchestrator")
            raise
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
from typing import Dict, Any, List, Literal, TypedDict | ||
from langgraph.graph import END, START, StateGraph | ||
from langgraph.prebuilt import ToolNode | ||
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage | ||
from langchain_anthropic import ChatAnthropic | ||
from langfuse.callback import CallbackHandler | ||
from .tools import tools | ||
from app.core.config import settings | ||
import logging | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
class MessagesState(TypedDict):
    """State schema carried through the orchestrator's LangGraph graph."""
    # Conversation history; the last entry drives should_continue routing.
    messages: List[BaseMessage]
    # Routing key: "research" selects the research model in call_model,
    # anything else falls through to the case model.
    mode: str
    # Initialised to "" by process_query; not read within this module.
    current_agent: str
    # Free-form metadata (process_query's **kwargs) carried alongside.
    metadata: Dict[str, Any]
|
||
class AgentOrchestrator:
    """Orchestrates multiple agents using LangGraph.

    Builds a two-node graph (LLM agent <-> tools) and routes each query
    through it, selecting a research- or case-bound model by ``mode``.
    """

    def __init__(self):
        self.tool_node = ToolNode(tools)
        # Langfuse callback handler used to trace every graph run.
        self.langfuse = CallbackHandler(
            public_key=settings.LANGFUSE_PUBLIC_KEY,
            secret_key=settings.LANGFUSE_SECRET_KEY,
            host=settings.LANGFUSE_HOST
        )
        # Initialize LLM models. Both currently share credentials and model
        # name but are kept as separate attributes so they can diverge in
        # configuration later.
        self.research_model = ChatAnthropic(
            anthropic_api_key=settings.ANTHROPIC_API_KEY,
            model_name=settings.ANTHROPIC_MODEL
        ).bind_tools(tools)
        self.case_model = ChatAnthropic(
            anthropic_api_key=settings.ANTHROPIC_API_KEY,
            model_name=settings.ANTHROPIC_MODEL
        ).bind_tools(tools)

        self.graph = self._build_graph()
        logger.info("Initialized Agent Orchestrator")

    def should_continue(self, state: MessagesState) -> str:
        """Route after an agent turn.

        Returns "tools" when the last AI message requested tool calls,
        otherwise END to stop the graph.

        NOTE: the original annotation ``Literal["tools", END]`` was invalid
        typing — END is a runtime constant, not a literal — so plain ``str``
        is used; runtime behavior is unchanged.
        """
        last_message = state["messages"][-1]
        if isinstance(last_message, AIMessage) and last_message.tool_calls:
            return "tools"
        return END

    def call_model(self, state: MessagesState) -> Dict:
        """Invoke the mode-appropriate model on the conversation so far.

        Returns:
            Partial state update ``{"messages": [response]}``.

        Raises:
            Exception: Logged (with traceback) and re-raised.
        """
        try:
            # "research" selects the research model; every other mode value
            # falls through to the case model.
            mode = state["mode"]
            model = self.research_model if mode == "research" else self.case_model
            response = model.invoke(state["messages"])
            return {"messages": [response]}
        except Exception:
            logger.exception("Error in model call")
            raise

    def _build_graph(self) -> StateGraph:
        """Build and compile the agent/tools interaction graph.

        NOTE(review): ``compile()`` returns a compiled graph object rather
        than a StateGraph; the original annotation is kept for compatibility.
        """
        graph = StateGraph(MessagesState)

        graph.add_node("agent", self.call_model)
        graph.add_node("tools", self.tool_node)

        # Every run starts at the agent node.
        graph.add_edge(START, "agent")

        # After each agent turn, either execute requested tools or finish.
        graph.add_conditional_edges(
            "agent",
            self.should_continue,
            {
                "tools": "tools",
                END: END
            }
        )

        # Tool output always flows back to the agent for another turn.
        graph.add_edge("tools", "agent")

        return graph.compile()

    async def process_query(
        self,
        query: str,
        mode: str = "research",
        thread_id: str | None = None,
        **kwargs
    ) -> Dict[str, Any]:
        """Process ``query`` through the compiled graph.

        Args:
            query: User question to answer.
            mode: "research" routes to the research model; any other value
                uses the case model.
            thread_id: Optional conversation thread id forwarded to the
                graph's configurable options. (Annotation fixed from the
                original ``str = None``.)
            **kwargs: Extra metadata stored in the graph state.

        Returns:
            The final graph state returned by ``ainvoke``.

        Raises:
            Exception: Logged (with traceback) and re-raised.
        """
        try:
            state = {
                "messages": [HumanMessage(content=query)],
                "mode": mode,
                "current_agent": "",
                "metadata": kwargs
            }

            # Trace the run through Langfuse; thread_id (when given) lets
            # the backend correlate turns of the same conversation.
            config = {
                "callbacks": [self.langfuse],
                "configurable": {"thread_id": thread_id} if thread_id else {}
            }

            return await self.graph.ainvoke(state, config=config)

        except Exception:
            logger.exception("Error in orchestrator")
            raise
Oops, something went wrong.