diff --git a/interpreter/terminal_interface/magic_commands.py b/interpreter/terminal_interface/magic_commands.py index 8047b19743..d884222fde 100644 --- a/interpreter/terminal_interface/magic_commands.py +++ b/interpreter/terminal_interface/magic_commands.py @@ -103,13 +103,14 @@ def handle_load_message(self, json_path): display_markdown_message(f"> messages json loaded from {os.path.abspath(json_path)}") def handle_count_tokens(self, arguments): + messages = [{"role": "system", "message": self.system_message}] + self.messages + if len(self.messages) == 0: - (tokens, cost) = count_messages_tokens(messages=[self.system_message], model=self.model) + (tokens, cost) = count_messages_tokens(messages=messages, model=self.model) display_markdown_message(f"> System Prompt Tokens: {tokens} (${cost})") else: - messages_including_system = [self.system_message] + self.messages - (tokens, cost) = count_messages_tokens(messages=messages_including_system, model=self.model) - display_markdown_message(f"> Tokens in Current Conversation: {tokens} (${cost})") + (tokens, cost) = count_messages_tokens(messages=messages, model=self.model) + display_markdown_message(f"> Conversation Tokens: {tokens} (${cost})") def handle_magic_command(self, user_input): # split the command into the command and the arguments, by the first whitespace diff --git a/interpreter/utils/count_tokens.py b/interpreter/utils/count_tokens.py index 0130d14031..bda66a325b 100644 --- a/interpreter/utils/count_tokens.py +++ b/interpreter/utils/count_tokens.py @@ -17,7 +17,7 @@ def token_cost(tokens=0, model="gpt-4"): (prompt_cost, _) = cost_per_token(model=model, prompt_tokens=tokens) - return prompt_cost + return round(prompt_cost, 6) def count_messages_tokens(messages=[], model=None): """ @@ -32,6 +32,12 @@ def count_messages_tokens(messages=[], model=None): elif "message" in message: tokens_used += count_tokens(message["message"], model=model) + if "code" in message: + tokens_used += count_tokens(message["code"], model=model) + + if "output" in message: + tokens_used += count_tokens(message["output"], model=model) + + prompt_cost = token_cost(tokens_used, model=model) return (tokens_used, prompt_cost)