add support for TokenUsage through callbacks (brainlid#137)
* support TokenUsage callbacks
- added OpenAI.stream_options support
- support TokenUsage on streamed and non-streamed responses

* upgraded req

* updates for req upgrade

* updated token usage for anthropic behavior

* updates
- support skipping a parsed message
- anthropic updates for updated req version
- fire token usage callback

* updated ChatBumblebee
- for new callbacks
- added TokenUsage support

* fixes for token usage on non-streamed messages
brainlid authored Jun 12, 2024
1 parent 6c01041 commit 9fc74c9
Showing 18 changed files with 606 additions and 182 deletions.
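To put the diffs below in context: token usage is now surfaced through the same callback-handler mechanism used for `on_llm_new_message` and `on_llm_ratelimit_info`. Below is a minimal sketch of subscribing to the new `:on_llm_token_usage` callback; the handler-map shape and the `ChatAnthropic.new/1` options are assumptions drawn from the callback lists in this diff, not code taken from the commit itself.

```elixir
alias LangChain.ChatModels.ChatAnthropic
alias LangChain.TokenUsage

# Handler map keyed by callback name; only the callbacks you care about are needed.
handler = %{
  on_llm_token_usage: fn _model, %TokenUsage{} = usage ->
    # input = prompt tokens, output = completion tokens reported by the API
    IO.puts("token usage - input: #{usage.input}, output: #{usage.output}")
  end
}

# Attach the handler when building the chat model (model name is illustrative).
{:ok, _chat} =
  ChatAnthropic.new(%{
    model: "claude-3-haiku-20240307",
    stream: true,
    callbacks: [handler]
  })
```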
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -13,6 +13,7 @@
* `LangChain.MessageProcessors.JsonProcessor` is capable of extracting JSON contents and converting it to an Elixir map using `Jason`. Parsing errors are returned to the LLM for it to try again.
* The attribute `processed_content` was added to a `LangChain.Message`. When a MessageProcessor is run on a received assistant message, the results of the processing are accumulated there. The original `content` remains unchanged for when it is sent back to the LLM and used when fixing or correcting its generated content.
* Callback support for LLM ratelimit information returned in API response headers. These are currently implemented for Anthropic and OpenAI.
* Callback support for LLM token usage information returned when available.

**Changed:**

65 changes: 48 additions & 17 deletions lib/chat_models/chat_anthropic.ex
@@ -50,6 +50,7 @@ defmodule LangChain.ChatModels.ChatAnthropic do
alias LangChain.Message.ToolCall
alias LangChain.Message.ToolResult
alias LangChain.MessageDelta
alias LangChain.TokenUsage
alias LangChain.Function
alias LangChain.FunctionParam
alias LangChain.Utils
@@ -339,7 +340,12 @@ defmodule LangChain.ChatModels.ChatAnthropic do
get_ratelimit_info(response.headers)
])

case do_process_response(data) do
Callbacks.fire(anthropic.callbacks, :on_llm_token_usage, [
anthropic,
get_token_usage(data)
])

case do_process_response(anthropic, data) do
{:error, reason} ->
{:error, reason}

Expand All @@ -348,10 +354,10 @@ defmodule LangChain.ChatModels.ChatAnthropic do
result
end

{:error, %Mint.TransportError{reason: :timeout}} ->
{:error, %Req.TransportError{reason: :timeout}} ->
{:error, "Request timed out"}

{:error, %Mint.TransportError{reason: :closed}} ->
{:error, %Req.TransportError{reason: :closed}} ->
# Force a retry by making a recursive call decrementing the counter
Logger.debug(fn -> "Mint connection closed: retry count = #{inspect(retry_count)}" end)
do_api_request(anthropic, messages, tools, retry_count - 1)
@@ -374,7 +380,10 @@ defmodule LangChain.ChatModels.ChatAnthropic do
headers: headers(get_api_key(anthropic), anthropic.api_version),
receive_timeout: anthropic.receive_timeout
)
|> Req.post(into: Utils.handle_stream_fn(anthropic, &decode_stream/1, &do_process_response/1))
|> Req.post(
into:
Utils.handle_stream_fn(anthropic, &decode_stream/1, &do_process_response(anthropic, &1))
)
|> case do
{:ok, %Req.Response{body: data} = response} ->
Callbacks.fire(anthropic.callbacks, :on_llm_ratelimit_info, [
@@ -387,10 +396,10 @@ defmodule LangChain.ChatModels.ChatAnthropic do
{:error, %LangChainError{message: reason}} ->
{:error, reason}

{:error, %Mint.TransportError{reason: :timeout}} ->
{:error, %Req.TransportError{reason: :timeout}} ->
{:error, "Request timed out"}

{:error, %Mint.TransportError{reason: :closed}} ->
{:error, %Req.TransportError{reason: :closed}} ->
# Force a retry by making a recursive call decrementing the counter
Logger.debug(fn -> "Mint connection closed: retry count = #{inspect(retry_count)}" end)
do_api_request(anthropic, messages, tools, retry_count - 1)
@@ -416,13 +425,13 @@ defmodule LangChain.ChatModels.ChatAnthropic do

# Parse a new message response
@doc false
@spec do_process_response(data :: %{String.t() => any()} | {:error, any()}) ::
@spec do_process_response(t(), data :: %{String.t() => any()} | {:error, any()}) ::
Message.t()
| [Message.t()]
| MessageDelta.t()
| [MessageDelta.t()]
| {:error, String.t()}
def do_process_response(%{
def do_process_response(_model, %{
"role" => "assistant",
"content" => contents,
"stop_reason" => stop_reason
@@ -441,7 +450,7 @@ defmodule LangChain.ChatModels.ChatAnthropic do
end)
end

def do_process_response(%{
def do_process_response(_model, %{
"type" => "content_block_start",
"content_block" => %{"type" => "text", "text" => content}
}) do
@@ -454,7 +463,7 @@ defmodule LangChain.ChatModels.ChatAnthropic do
|> to_response()
end

def do_process_response(%{
def do_process_response(_model, %{
"type" => "content_block_delta",
"delta" => %{"type" => "text_delta", "text" => content}
}) do
@@ -467,10 +476,17 @@ defmodule LangChain.ChatModels.ChatAnthropic do
|> to_response()
end

def do_process_response(%{
"type" => "message_delta",
"delta" => %{"stop_reason" => stop_reason}
}) do
def do_process_response(
model,
%{
"type" => "message_delta",
"delta" => %{"stop_reason" => stop_reason},
"usage" => _usage
} = data
) do
# if we received usage data, fire any callbacks for it.
Callbacks.fire(model.callbacks, :on_llm_token_usage, [model, get_token_usage(data)])

%{
role: :assistant,
content: "",
@@ -480,18 +496,18 @@ defmodule LangChain.ChatModels.ChatAnthropic do
|> to_response()
end

def do_process_response(%{"error" => %{"message" => reason}}) do
def do_process_response(_model, %{"error" => %{"message" => reason}}) do
Logger.error("Received error from API: #{inspect(reason)}")
{:error, reason}
end

def do_process_response({:error, %Jason.DecodeError{} = response}) do
def do_process_response(_model, {:error, %Jason.DecodeError{} = response}) do
error_message = "Received invalid JSON: #{inspect(response)}"
Logger.error(error_message)
{:error, error_message}
end

def do_process_response(other) do
def do_process_response(_model, other) do
Logger.error("Trying to process an unexpected response. #{inspect(other)}")
{:error, "Unexpected response"}
end
@@ -835,4 +851,19 @@ defmodule LangChain.ChatModels.ChatAnthropic do

return
end

defp get_token_usage(%{"usage" => usage} = _response_body) do
# extract out the reported response token usage
#
# https://platform.openai.com/docs/api-reference/chat/object#chat/object-usage
TokenUsage.new!(%{
input: Map.get(usage, "input_tokens"),
output: Map.get(usage, "output_tokens")
})
end

defp get_token_usage(_response_body), do: %{}
end
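The new `get_token_usage/1` helper reads the `"usage"` map Anthropic returns, both on the complete (non-streamed) response body and on the streamed `message_delta` event handled above, and wraps the `"input_tokens"`/`"output_tokens"` counts in a `LangChain.TokenUsage` struct before the `:on_llm_token_usage` callback fires. A rough sketch of that mapping follows; the payload and its token counts are invented for illustration.

```elixir
# Relevant shape of a non-streamed Anthropic response body (values invented).
data = %{
  "role" => "assistant",
  "content" => [%{"type" => "text", "text" => "Hello!"}],
  "stop_reason" => "end_turn",
  "usage" => %{"input_tokens" => 55, "output_tokens" => 4}
}

# Mirrors what the private get_token_usage/1 does before the callback fires.
usage =
  LangChain.TokenUsage.new!(%{
    input: get_in(data, ["usage", "input_tokens"]),
    output: get_in(data, ["usage", "output_tokens"])
  })

IO.inspect({usage.input, usage.output})
# => {55, 4}
```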
67 changes: 43 additions & 24 deletions lib/chat_models/chat_bumblebee.ex
@@ -102,10 +102,13 @@ defmodule LangChain.ChatModels.ChatBumblebee do
alias __MODULE__
alias LangChain.ChatModels.ChatModel
alias LangChain.Message
alias LangChain.Function
alias LangChain.TokenUsage
alias LangChain.LangChainError
alias LangChain.Utils
alias LangChain.MessageDelta
alias LangChain.Utils.ChatTemplates
alias LangChain.Callbacks

@behaviour ChatModel

@@ -131,6 +134,9 @@
# Seed for randomizing behavior or giving more deterministic output. Helpful
# for testing.
field :seed, :integer, default: nil

# A list of maps for callback handlers
field :callbacks, {:array, :map}, default: []
end

@type t :: %ChatBumblebee{}
@@ -146,7 +152,8 @@
# :temperature,
:seed,
:template_format,
:stream
:stream,
:callbacks
]
@required_fields [:serving]

@@ -181,27 +188,27 @@
end

@impl ChatModel
def call(model, prompt, functions \\ [], callback_fn \\ nil)
def call(model, prompt, functions \\ [])

def call(%ChatBumblebee{} = model, prompt, functions, callback_fn) when is_binary(prompt) do
def call(%ChatBumblebee{} = model, prompt, functions) when is_binary(prompt) do
messages = [
Message.new_system!(),
Message.new_user!(prompt)
]

call(model, messages, functions, callback_fn)
call(model, messages, functions)
end

def call(%ChatBumblebee{} = model, messages, functions, callback_fn)
def call(%ChatBumblebee{} = model, messages, functions)
when is_list(messages) do
if override_api_return?() do
Logger.warning("Found override API response. Will not make live API call.")

# fire callback for fake responses too
case get_api_override() do
{:ok, {:ok, data} = response} ->
# fire callback for fake responses too
Utils.fire_callback(model, data, callback_fn)
response
{:ok, {:ok, data, callback_name}} ->
Callbacks.fire(model.callbacks, callback_name, [model, data])
{:ok, data}

_other ->
raise LangChainError,
Expand All @@ -210,7 +217,7 @@ defmodule LangChain.ChatModels.ChatBumblebee do
else
try do
# make base api request and perform high-level success/failure checks
case do_serving_request(model, messages, functions, callback_fn) do
case do_serving_request(model, messages, functions) do
{:error, reason} ->
{:error, reason}

@@ -225,27 +232,28 @@ defmodule LangChain.ChatModels.ChatBumblebee do
end

@doc false
@spec do_serving_request(t(), [Message.t()], [Function.t()], callback_fn()) ::
@spec do_serving_request(t(), [Message.t()], [Function.t()]) ::
list() | struct() | {:error, String.t()}
def do_serving_request(%ChatBumblebee{} = model, messages, _functions, callback_fn) do
def do_serving_request(%ChatBumblebee{} = model, messages, _functions) do
prompt = ChatTemplates.apply_chat_template!(messages, model.template_format)

model.serving
|> Nx.Serving.batched_run(%{text: prompt, seed: model.seed})
|> do_process_response(model, callback_fn)
|> do_process_response(model)
end

@doc false
def do_process_response(
%{results: [%{text: content, token_summary: _token_summary}]},
%ChatBumblebee{} = model,
callback_fn
%{results: [%{text: content, token_summary: token_summary}]},
%ChatBumblebee{} = model
)
when is_binary(content) do
fire_token_usage_callback(model, token_summary)

case Message.new(%{role: :assistant, status: :complete, content: content}) do
{:ok, message} ->
# execute the callback with the final message
Utils.fire_callback(model, [message], callback_fn)
Callbacks.fire(model.callbacks, :on_llm_new_message, [model, message])
# return a list of the complete message. As a list for compatibility.
[message]

@@ -256,32 +264,34 @@ defmodule LangChain.ChatModels.ChatBumblebee do
end
end

def do_process_response(stream, %ChatBumblebee{stream: false} = model, callback_fn) do
def do_process_response(stream, %ChatBumblebee{stream: false} = model) do
# Request is to NOT stream. Consume the full stream and format the data as
# though it had not been streamed.
full_data =
Enum.reduce(stream, %{text: "", token_summary: nil}, fn
{:done, token_data}, %{text: text} ->
{:done, %{token_summary: token_data}}, %{text: text} ->
%{text: text, token_summary: token_data}

data, %{text: text} = acc ->
Map.put(acc, :text, text <> data)
end)

do_process_response(%{results: [full_data]}, model, callback_fn)
do_process_response(%{results: [full_data]}, model)
end

def do_process_response(stream, %ChatBumblebee{} = model, callback_fn) do
def do_process_response(stream, %ChatBumblebee{} = model) do
chunk_processor = fn
{:done, _token_data} ->
{:done, %{token_summary: token_summary}} ->
fire_token_usage_callback(model, token_summary)

final_delta = MessageDelta.new!(%{role: :assistant, status: :complete})
Utils.fire_callback(model, [final_delta], callback_fn)
Callbacks.fire(model.callbacks, :on_llm_new_delta, [model, final_delta])
final_delta

content when is_binary(content) ->
case MessageDelta.new(%{content: content, role: :assistant, status: :incomplete}) do
{:ok, delta} ->
Utils.fire_callback(model, [delta], callback_fn)
Callbacks.fire(model.callbacks, :on_llm_new_delta, [model, delta])
delta

{:error, changeset} ->
@@ -303,4 +313,13 @@
# return a list of a list to mirror the way ChatGPT returns data
[result]
end

defp fire_token_usage_callback(model, %{input: input, output: output} = _token_summary) do
Callbacks.fire(model.callbacks, :on_llm_token_usage, [
model,
TokenUsage.new!(%{input: input, output: output})
])
end

defp fire_token_usage_callback(_model, _token_summary), do: :ok
end
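ChatBumblebee now carries its own `callbacks` list (the new `:callbacks` field above) and converts Bumblebee's end-of-stream `{:done, %{token_summary: ...}}` tuple into a `TokenUsage` struct via `fire_token_usage_callback/2`. A sketch of wiring handlers onto it follows; the serving module name and template format are placeholders, not values from this commit.

```elixir
alias LangChain.ChatModels.ChatBumblebee
alias LangChain.TokenUsage

handler = %{
  # streamed text chunks; the final delta may have nil content
  on_llm_new_delta: fn _model, delta -> IO.write(delta.content || "") end,
  on_llm_token_usage: fn _model, %TokenUsage{input: input, output: output} ->
    IO.puts("\nBumblebee reported #{input} input / #{output} output tokens")
  end
}

# :serving points at an already-started Nx.Serving; the name here is hypothetical.
{:ok, _chat} =
  ChatBumblebee.new(%{
    serving: MyApp.ChatServing,
    template_format: :llama_2,
    stream: true,
    callbacks: [handler]
  })
```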
4 changes: 2 additions & 2 deletions lib/chat_models/chat_google_ai.ex
@@ -288,7 +288,7 @@ defmodule LangChain.ChatModels.ChatGoogleAI do
result
end

{:error, %Mint.TransportError{reason: :timeout}} ->
{:error, %Req.TransportError{reason: :timeout}} ->
{:error, "Request timed out"}

other ->
@@ -323,7 +323,7 @@
{:error, %LangChainError{message: reason}} ->
{:error, reason}

{:error, %Mint.TransportError{reason: :timeout}} ->
{:error, %Req.TransportError{reason: :timeout}} ->
{:error, "Request timed out"}

other ->
8 changes: 4 additions & 4 deletions lib/chat_models/chat_mistral_ai.ex
@@ -213,10 +213,10 @@ defmodule Langchain.ChatModels.ChatMistralAI do
result
end

{:error, %Mint.TransportError{reason: :timeout}} ->
{:error, %Req.TransportError{reason: :timeout}} ->
{:error, "Request timed out"}

{:error, %Mint.TransportError{reason: :closed}} ->
{:error, %Req.TransportError{reason: :closed}} ->
# Force a retry by making a recursive call decrementing the counter
Logger.debug(fn -> "Mint connection closed: retry count = #{inspect(retry_count)}" end)
do_api_request(mistral, messages, functions, callback_fn, retry_count - 1)
@@ -256,10 +256,10 @@ defmodule Langchain.ChatModels.ChatMistralAI do
{:error, %LangChainError{message: reason}} ->
{:error, reason}

{:error, %Mint.TransportError{reason: :timeout}} ->
{:error, %Req.TransportError{reason: :timeout}} ->
{:error, "Request timed out"}

{:error, %Mint.TransportError{reason: :closed}} ->
{:error, %Req.TransportError{reason: :closed}} ->
# Force a retry by making a recursive call decrementing the counter
Logger.debug(fn -> "Mint connection closed: retry count = #{inspect(retry_count)}" end)
do_api_request(mistral, messages, functions, callback_fn, retry_count - 1)
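The swap from `Mint.TransportError` to `Req.TransportError` here (and in ChatAnthropic above) follows from the `req` upgrade: newer `req` versions wrap transport failures in their own error struct, so the timeout and connection-closed clauses must match on it. A standalone sketch of the same matching, assuming `req` >= 0.5, with a placeholder URL:

```elixir
# Build a request; the URL is fake and the short timeout just forces an error path.
req = Req.new(url: "https://example.invalid/v1/chat", receive_timeout: 50, retry: false)

case Req.post(req) do
  {:ok, %Req.Response{} = response} ->
    {:ok, response}

  {:error, %Req.TransportError{reason: :timeout}} ->
    {:error, "Request timed out"}

  {:error, %Req.TransportError{reason: :closed}} ->
    # the chat modules above recover here by recursing with a decremented retry counter
    :retry

  other ->
    {:error, "Unexpected response: #{inspect(other)}"}
end
```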