Merge pull request stitionai#112 from ketangangal/feature-dev/token-usage-from-agent-state

Token usage should be inferred from AgentState instead of a lazily initialized global value (TOKEN_USAGE) stitionai#10
mufeedvh authored Mar 26, 2024
2 parents 79b41cb + 70641bf commit c00a390
Showing 17 changed files with 78 additions and 51 deletions.
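In short, the commit replaces the process-wide TOKEN_USAGE counter with per-project accounting: LLM.inference counts prompt and response tokens with tiktoken and adds them to the latest entry of the project's state stack via AgentState, and the /api/token-usage endpoint reads them back per project. A minimal sketch of the resulting flow, with the database persistence stubbed out as an in-memory dict (the real code stores the stack as JSON on AgentStateModel):

# Sketch only: AgentState is simplified to an in-memory dict here; the real
# class persists the state stack as JSON in a SQL table (AgentStateModel).
import tiktoken

TIKTOKEN_ENC = tiktoken.get_encoding("cl100k_base")

class AgentStateSketch:
    # class attribute, so fresh instances share state, mirroring how the
    # real code instantiates AgentState() anew on every call
    _stacks: dict = {}  # project name -> state stack (list of dicts)

    def new_state(self) -> dict:
        # "token_usage" and "completed" appear in the diff; other fields omitted
        return {"token_usage": 0, "completed": False}

    def update_token_usage(self, project: str, token_usage: int) -> None:
        stack = self._stacks.setdefault(project, [self.new_state()])
        stack[-1]["token_usage"] += token_usage

    def get_latest_token_usage(self, project: str) -> int:
        stack = self._stacks.get(project)
        return stack[-1]["token_usage"] if stack else 0

def inference(prompt: str, project_name: str) -> str:
    response = "<model output>"  # actual model call elided
    for text in (prompt, response):  # counted once for prompt, once for response
        AgentStateSketch().update_token_usage(
            project_name, len(TIKTOKEN_ENC.encode(text))
        )
    return response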
5 changes: 3 additions & 2 deletions devika.py
@@ -160,8 +160,9 @@ def calculate_tokens():
 @app.route("/api/token-usage", methods=["GET"])
 @route_logger(logger)
 def token_usage():
-    from src.llm import TOKEN_USAGE
-    return jsonify({"token_usage": TOKEN_USAGE})
+    project_name = request.args.get("project_name")
+    token_count = AgentState().get_latest_token_usage(project_name)
+    return jsonify({"token_usage": token_count})
 
 
 @app.route("/api/real-time-logs", methods=["GET"])
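With the query-string parameter in place, the endpoint is queried per project. A usage example; the base URL is an assumption, adjust it to wherever the devika backend is served:

# Hypothetical client call; host and port are assumptions, not from the diff.
import requests

resp = requests.get(
    "http://127.0.0.1:1337/api/token-usage",
    params={"project_name": "my-project"},  # must name an existing project
)
print(resp.json())  # e.g. {"token_usage": 1234}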
6 changes: 3 additions & 3 deletions src/agents/action/action.py
@@ -39,15 +39,15 @@ def validate_response(self, response: str):
         else:
             return response["response"], response["action"]
 
-    def execute(self, conversation: list) -> str:
+    def execute(self, conversation: list, project_name: str) -> str:
         prompt = self.render(conversation)
-        response = self.llm.inference(prompt)
+        response = self.llm.inference(prompt, project_name)
 
         valid_response = self.validate_response(response)
 
         while not valid_response:
             print("Invalid response from the model, trying again...")
-            return self.execute(conversation)
+            return self.execute(conversation, project_name)
 
         print("===" * 10)
         print(valid_response)
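Note the control flow shared by these execute() methods: because the loop body returns immediately, `while not valid_response:` behaves like an `if`, and retries actually happen through recursion, unboundedly if the model keeps producing invalid output. An iterative restatement, shown only to clarify the pattern (the retry cap is an illustrative addition, not in the source):

# Iterative equivalent of the retry-by-recursion pattern; max_retries is a
# hypothetical addition. Method sketch, assumes render/validate_response/llm
# as defined on the agent classes.
def execute(self, conversation: list, project_name: str, max_retries: int = 5):
    for _ in range(max_retries):
        prompt = self.render(conversation)
        response = self.llm.inference(prompt, project_name)
        valid_response = self.validate_response(response)
        if valid_response:
            return valid_response
        print("Invalid response from the model, trying again...")
    raise RuntimeError("no valid response after retries")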
24 changes: 13 additions & 11 deletions src/agents/agent.py
@@ -93,7 +93,8 @@ def search_queries(self, queries: list, project_name: str) -> dict:
             Formatter Agent is invoked to format and learn from the contents
             """
             results[query] = self.formatter.execute(
-                browser.extract_text()
+                browser.extract_text(),
+                project_name
             )
 
             """
@@ -118,7 +119,7 @@ def update_contextual_keywords(self, sentence: str):
     Decision making Agent
     """
     def make_decision(self, prompt: str, project_name: str) -> str:
-        decision = self.decision.execute(prompt)
+        decision = self.decision.execute(prompt, project_name)
 
         for item in decision:
             function = item["function"]
@@ -134,7 +135,7 @@ def make_decision(self, prompt: str, project_name: str) -> str:
         elif function == "generate_pdf_document":
             user_prompt = args["user_prompt"]
             # Call the reporter agent to generate the PDF document
-            markdown = self.reporter.execute([user_prompt], "")
+            markdown = self.reporter.execute([user_prompt], "", project_name)
             _out_pdf_file = PDF().markdown_to_pdf(markdown, project_name)
 
             project_name_space_url = project_name.replace(" ", "%20")
@@ -154,10 +155,10 @@ def make_decision(self, prompt: str, project_name: str) -> str:
         elif function == "coding_project":
             user_prompt = args["user_prompt"]
             # Call the planner, researcher, coder agents in sequence
-            plan = self.planner.execute(user_prompt)
+            plan = self.planner.execute(user_prompt, project_name)
             planner_response = self.planner.parse_response(plan)
 
-            research = self.researcher.execute(plan, self.collected_context_keywords)
+            research = self.researcher.execute(plan, self.collected_context_keywords, project_name)
             search_results = self.search_queries(research["queries"], project_name)
 
             code = self.coder.execute(
@@ -177,7 +178,7 @@ def subsequent_execute(self, prompt: str, project_name: str) -> str:
         conversation = ProjectManager().get_all_messages_formatted(project_name)
         code_markdown = ReadCode(project_name).code_set_to_markdown()
 
-        response, action = self.action.execute(conversation)
+        response, action = self.action.execute(conversation, project_name)
 
         ProjectManager().add_message_from_devika(project_name, response)
 
@@ -188,7 +189,8 @@ def subsequent_execute(self, prompt: str, project_name: str) -> str:
         if action == "answer":
             response = self.answer.execute(
                 conversation=conversation,
-                code_markdown=code_markdown
+                code_markdown=code_markdown,
+                project_name=project_name
             )
             ProjectManager().add_message_from_devika(project_name, response)
         elif action == "run":
@@ -238,7 +240,7 @@ def subsequent_execute(self, prompt: str, project_name: str) -> str:
 
             self.patcher.save_code_to_project(code, project_name)
         elif action == "report":
-            markdown = self.reporter.execute(conversation, code_markdown)
+            markdown = self.reporter.execute(conversation, code_markdown, project_name)
 
             _out_pdf_file = PDF().markdown_to_pdf(markdown, project_name)
 
@@ -261,7 +263,7 @@ def execute(self, prompt: str, project_name_from_user: str = None) -> str:
         if project_name_from_user:
             ProjectManager().add_message_from_user(project_name_from_user, prompt)
 
-        plan = self.planner.execute(prompt)
+        plan = self.planner.execute(prompt, project_name_from_user)
         print(plan)
         print("=====" * 10)
 
@@ -288,15 +290,15 @@ def execute(self, prompt: str, project_name_from_user: str = None) -> str:
         self.update_contextual_keywords(focus)
         print(self.collected_context_keywords)
 
-        internal_monologue = self.internal_monologue.execute(current_prompt=plan)
+        internal_monologue = self.internal_monologue.execute(current_prompt=plan, project_name=project_name)
         print(internal_monologue)
         print("=====" * 10)
 
         new_state = AgentState().new_state()
         new_state["internal_monologue"] = internal_monologue
         AgentState().add_to_current_state(project_name, new_state)
 
-        research = self.researcher.execute(plan, self.collected_context_keywords)
+        research = self.researcher.execute(plan, self.collected_context_keywords, project_name)
         print(research)
         print("=====" * 10)
 
6 changes: 3 additions & 3 deletions src/agents/answer/answer.py
@@ -40,14 +40,14 @@ def validate_response(self, response: str):
         else:
             return response["response"]
 
-    def execute(self, conversation: list, code_markdown: str) -> str:
+    def execute(self, conversation: list, code_markdown: str, project_name: str) -> str:
         prompt = self.render(conversation, code_markdown)
-        response = self.llm.inference(prompt)
+        response = self.llm.inference(prompt, project_name)
 
         valid_response = self.validate_response(response)
 
         while not valid_response:
             print("Invalid response from the model, trying again...")
-            return self.execute(conversation, code_markdown)
+            return self.execute(conversation, code_markdown, project_name)
 
         return valid_response
4 changes: 2 additions & 2 deletions src/agents/coder/coder.py
@@ -102,13 +102,13 @@ def execute(
         project_name: str
     ) -> str:
         prompt = self.render(step_by_step_plan, user_context, search_results)
-        response = self.llm.inference(prompt)
+        response = self.llm.inference(prompt, project_name)
 
         valid_response = self.validate_response(response)
 
         while not valid_response:
             print("Invalid response from the model, trying again...")
-            return self.execute(step_by_step_plan, user_context, search_results)
+            return self.execute(step_by_step_plan, user_context, search_results, project_name)
 
         print(valid_response)
 
6 changes: 3 additions & 3 deletions src/agents/decision/decision.py
@@ -32,14 +32,14 @@ def validate_response(self, response: str):
 
         return response
 
-    def execute(self, prompt: str) -> str:
+    def execute(self, prompt: str, project_name: str) -> str:
         prompt = self.render(prompt)
-        response = self.llm.inference(prompt)
+        response = self.llm.inference(prompt, project_name)
 
         valid_response = self.validate_response(response)
 
         while not valid_response:
             print("Invalid response from the model, trying again...")
-            return self.execute(prompt)
+            return self.execute(prompt, project_name)
 
         return valid_response
2 changes: 1 addition & 1 deletion src/agents/feature/feature.py
@@ -103,7 +103,7 @@ def execute(
         project_name: str
     ) -> str:
         prompt = self.render(conversation, code_markdown, system_os)
-        response = self.llm.inference(prompt)
+        response = self.llm.inference(prompt, project_name)
 
         valid_response = self.validate_response(response)
 
4 changes: 2 additions & 2 deletions src/agents/formatter/formatter.py
@@ -16,7 +16,7 @@ def render(self, raw_text: str) -> str:
     def validate_response(self, response: str) -> bool:
         return True
 
-    def execute(self, raw_text: str) -> str:
+    def execute(self, raw_text: str, project_name: str) -> str:
         raw_text = self.render(raw_text)
-        response = self.llm.inference(raw_text)
+        response = self.llm.inference(raw_text, project_name)
         return response
6 changes: 3 additions & 3 deletions src/agents/internal_monologue/internal_monologue.py
@@ -33,15 +33,15 @@ def validate_response(self, response: str):
         else:
             return response["internal_monologue"]
 
-    def execute(self, current_prompt: str) -> str:
+    def execute(self, current_prompt: str, project_name: str) -> str:
         current_prompt = self.render(current_prompt)
-        response = self.llm.inference(current_prompt)
+        response = self.llm.inference(current_prompt, project_name)
 
         valid_response = self.validate_response(response)
 
         while not valid_response:
             print("Invalid response from the model, trying again...")
-            return self.execute(current_prompt)
+            return self.execute(current_prompt, project_name)
 
         return valid_response

2 changes: 1 addition & 1 deletion src/agents/patcher/patcher.py
@@ -115,7 +115,7 @@ def execute(
             error,
             system_os
         )
-        response = self.llm.inference(prompt)
+        response = self.llm.inference(prompt, project_name)
 
         valid_response = self.validate_response(response)
 
4 changes: 2 additions & 2 deletions src/agents/planner/planner.py
@@ -65,7 +65,7 @@ def parse_response(self, response: str):
 
         return result
 
-    def execute(self, prompt: str) -> str:
+    def execute(self, prompt: str, project_name: str) -> str:
         prompt = self.render(prompt)
-        response = self.llm.inference(prompt)
+        response = self.llm.inference(prompt, project_name)
         return response
7 changes: 4 additions & 3 deletions src/agents/reporter/reporter.py
@@ -28,16 +28,17 @@ def validate_response(self, response: str):
 
     def execute(self,
                 conversation: list,
-                code_markdown: str
+                code_markdown: str,
+                project_name: str
                 ) -> str:
         prompt = self.render(conversation, code_markdown)
-        response = self.llm.inference(prompt)
+        response = self.llm.inference(prompt, project_name)
 
         valid_response = self.validate_response(response)
 
         while not valid_response:
             print("Invalid response from the model, trying again...")
-            return self.execute(conversation, code_markdown)
+            return self.execute(conversation, code_markdown, project_name)
 
         return valid_response

6 changes: 3 additions & 3 deletions src/agents/researcher/researcher.py
@@ -42,16 +42,16 @@ def validate_response(self, response: str):
             "ask_user": response["ask_user"]
         }
 
-    def execute(self, step_by_step_plan: str, contextual_keywords: List[str]) -> str:
+    def execute(self, step_by_step_plan: str, contextual_keywords: List[str], project_name: str) -> str:
         contextual_keywords = ", ".join(map(lambda k: k.capitalize(), contextual_keywords))
         step_by_step_plan = self.render(step_by_step_plan, contextual_keywords)
 
-        response = self.llm.inference(step_by_step_plan)
+        response = self.llm.inference(step_by_step_plan, project_name)
 
         valid_response = self.validate_response(response)
 
         while not valid_response:
             print("Invalid response from the model, trying again...")
-            return self.execute(step_by_step_plan, contextual_keywords)
+            return self.execute(step_by_step_plan, contextual_keywords, project_name)
 
         return valid_response
4 changes: 2 additions & 2 deletions src/agents/runner/runner.py
@@ -136,7 +136,7 @@ def run_code(
             error=command_output
         )
 
-        response = self.llm.inference(prompt)
+        response = self.llm.inference(prompt, project_name)
 
         valid_response = self.validate_rerunner_response(response)
 
@@ -233,7 +233,7 @@ def execute(
         project_name: str
     ) -> str:
         prompt = self.render(conversation, code_markdown, os_system)
-        response = self.llm.inference(prompt)
+        response = self.llm.inference(prompt, project_name)
 
         valid_response = self.validate_response(response)
 
2 changes: 1 addition & 1 deletion src/llm/__init__.py
@@ -1 +1 @@
-from .llm import LLM, TOKEN_USAGE
+from .llm import LLM
16 changes: 8 additions & 8 deletions src/llm/llm.py
@@ -5,9 +5,10 @@
 from .openai_client import OpenAI
 from .groq_client import Groq
 
+from src.state import AgentState
+
 import tiktoken
 
-TOKEN_USAGE = 0
 TIKTOKEN_ENC = tiktoken.get_encoding("cl100k_base")
 
 class Model(Enum):
@@ -40,15 +41,14 @@ def model_id_to_enum_mapping(self):
         models.update(ollama_models)
         return models
 
-    def update_global_token_usage(self, string: str):
-        global TOKEN_USAGE
-        TOKEN_USAGE += len(TIKTOKEN_ENC.encode(string))
-        print(f"Token usage: {TOKEN_USAGE}")
+    def update_global_token_usage(self, string: str, project_name: str):
+        token_usage = len(TIKTOKEN_ENC.encode(string))
+        AgentState().update_token_usage(project_name, token_usage)
 
     def inference(
-        self, prompt: str
+        self, prompt: str, project_name: str
     ) -> str:
-        self.update_global_token_usage(prompt)
+        self.update_global_token_usage(prompt, project_name)
 
         model = self.model_id_to_enum_mapping()[self.model_id]
 
@@ -63,6 +63,6 @@ def inference(
         else:
             raise ValueError(f"Model {model} not supported")
 
-        self.update_global_token_usage(response)
+        self.update_global_token_usage(response, project_name)
 
         return response
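The counting itself is local and model-agnostic: both the prompt and the response are encoded with tiktoken's cl100k_base encoding and the resulting lengths are added to the project's state. A self-contained example of that counting step (cl100k_base matches OpenAI's chat models; for other backends the figure is an approximation):

# Token counting as done by update_global_token_usage: encode with
# cl100k_base and take the length. The example strings are illustrative.
import tiktoken

enc = tiktoken.get_encoding("cl100k_base")

prompt = "Write a snake game in Python."
response = "Sure, here is a pygame implementation..."

print(len(enc.encode(prompt)))    # tokens recorded before the model call
print(len(enc.encode(response)))  # tokens recorded after the model call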
25 changes: 24 additions & 1 deletion src/state.py
@@ -132,4 +132,27 @@ def is_agent_completed(self, project: str):
             agent_state = session.query(AgentStateModel).filter(AgentStateModel.project == project).first()
             if agent_state:
                 return json.loads(agent_state.state_stack_json)[-1]["completed"]
-        return None
+        return None
+
+    def update_token_usage(self, project: str, token_usage: int):
+        with Session(self.engine) as session:
+            agent_state = session.query(AgentStateModel).filter(AgentStateModel.project == project).first()
+            print(agent_state)
+            if agent_state:
+                state_stack = json.loads(agent_state.state_stack_json)
+                state_stack[-1]["token_usage"] += token_usage
+                agent_state.state_stack_json = json.dumps(state_stack)
+                session.commit()
+            else:
+                state_stack = [self.new_state()]
+                state_stack[-1]["token_usage"] = token_usage
+                agent_state = AgentStateModel(project=project, state_stack_json=json.dumps(state_stack))
+                session.add(agent_state)
+                session.commit()
+
+    def get_latest_token_usage(self, project: str):
+        with Session(self.engine) as session:
+            agent_state = session.query(AgentStateModel).filter(AgentStateModel.project == project).first()
+            if agent_state:
+                return json.loads(agent_state.state_stack_json)[-1]["token_usage"]
+            return 0
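Since update_token_usage does += on the latest stack entry, new_state() presumably initializes a token_usage field, and get_latest_token_usage reports the counter of the newest state rather than a lifetime total across entries. A small sketch of the JSON round-trip these methods perform (field names are from the diff; the SQLAlchemy session and model are elided):

# The state stack is persisted as a JSON string; each update mutates only
# the newest entry's "token_usage" counter and re-serializes the stack.
import json

state_stack_json = json.dumps([{"token_usage": 0, "completed": False}])

def add_usage(stack_json: str, token_usage: int) -> str:
    stack = json.loads(stack_json)
    stack[-1]["token_usage"] += token_usage  # mirrors update_token_usage
    return json.dumps(stack)

state_stack_json = add_usage(state_stack_json, 120)  # prompt tokens
state_stack_json = add_usage(state_stack_json, 450)  # response tokens

# mirrors get_latest_token_usage
print(json.loads(state_stack_json)[-1]["token_usage"])  # -> 570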
