groq[minor]: Implement streaming tool calls (langchain-ai#6203)
* implemented and added test

* chore: lint files

* ayrn

* chore: lint files

* ensure name/id fields are only yielded once for streaming tool calls
bracesproul authored Jul 25, 2024
1 parent 3e18ab8 commit a614d1d
Showing 5 changed files with 192 additions and 141 deletions.
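For context, this change makes ChatGroq emit tool calls incrementally over the stream (previously, binding tools forced a non-streaming fallback, as the removed code below shows). A minimal consumer might look like the following sketch — the model name, tool, and prompt are illustrative, and the aggregation pattern mirrors the integration test added in this commit:

import { ChatGroq } from "@langchain/groq";
import { AIMessageChunk } from "@langchain/core/messages";
import { tool } from "@langchain/core/tools";
import { concat } from "@langchain/core/utils/stream";
import { z } from "zod";

// Hypothetical tool; any zod-described tool works the same way.
const weatherTool = tool(() => "sunny", {
  name: "get_current_weather",
  description: "Get the current weather in a given location.",
  schema: z.object({ location: z.string() }),
});

const model = new ChatGroq({ model: "llama-3.1-70b-versatile" }).bindTools([
  weatherTool,
]);

// The tool-call name/id arrive in a single chunk; the argument JSON
// arrives as string fragments that `concat` merges back together.
let final: AIMessageChunk | undefined;
for await (const chunk of await model.stream("What is the weather in SF?")) {
  final = final ? concat(final, chunk) : chunk;
}
console.log(final?.tool_calls);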
4 changes: 2 additions & 2 deletions libs/langchain-groq/package.json
@@ -35,9 +35,9 @@
"author": "LangChain",
"license": "MIT",
"dependencies": {
"@langchain/core": ">=0.2.16 <0.3.0",
"@langchain/core": ">=0.2.18 <0.3.0",
"@langchain/openai": "~0.2.4",
"groq-sdk": "^0.3.2",
"groq-sdk": "^0.5.0",
"zod": "^3.22.4",
"zod-to-json-schema": "^3.22.5"
},
217 changes: 136 additions & 81 deletions libs/langchain-groq/src/chat_models.ts
@@ -8,6 +8,8 @@ import {
LangSmithParams,
type BaseChatModelParams,
} from "@langchain/core/language_models/chat_models";
+import * as ChatCompletionsAPI from "groq-sdk/resources/chat/completions";
+import * as CompletionsAPI from "groq-sdk/resources/completions";
import {
AIMessage,
AIMessageChunk,
@@ -19,6 +21,7 @@ import {
ToolMessage,
OpenAIToolCall,
isAIMessage,
+  BaseMessageChunk,
} from "@langchain/core/messages";
import {
ChatGeneration,
@@ -32,7 +35,6 @@
} from "@langchain/openai";
import { isZodSchema } from "@langchain/core/utils/types";
import Groq from "groq-sdk";
-import { ChatCompletionChunk } from "groq-sdk/lib/chat_completions_ext";
import {
ChatCompletion,
ChatCompletionCreateParams,
@@ -146,8 +148,8 @@ export function messageToGroqRole(message: BaseMessage): GroqRoleEnum {

function convertMessagesToGroqParams(
messages: BaseMessage[]
-): Array<ChatCompletion.Choice.Message> {
-  return messages.map((message): ChatCompletion.Choice.Message => {
+): Array<ChatCompletionsAPI.ChatCompletionMessage> {
+  return messages.map((message): ChatCompletionsAPI.ChatCompletionMessage => {
if (typeof message.content !== "string") {
throw new Error("Non string message content not supported");
}
@@ -172,12 +174,12 @@ function convertMessagesToGroqParams(
completionParam.tool_call_id = (message as ToolMessage).tool_call_id;
}
}
-    return completionParam as ChatCompletion.Choice.Message;
+    return completionParam as ChatCompletionsAPI.ChatCompletionMessage;
});
}

function groqResponseToChatMessage(
-  message: ChatCompletion.Choice.Message
+  message: ChatCompletionsAPI.ChatCompletionMessage
): BaseMessage {
const rawToolCalls: OpenAIToolCall[] | undefined = message.tool_calls as
| OpenAIToolCall[]
@@ -206,10 +208,34 @@ function groqResponseToChatMessage(
}
}

+function _convertDeltaToolCallToToolCallChunk(
+  toolCalls?: ChatCompletionsAPI.ChatCompletionChunk.Choice.Delta.ToolCall[],
+  index?: number
+): ToolCallChunk[] | undefined {
+  if (!toolCalls?.length) return undefined;
+
+  return toolCalls.map((tc) => ({
+    id: tc.id,
+    name: tc.function?.name,
+    args: tc.function?.arguments,
+    type: "tool_call_chunk",
+    index,
+  }));
+}

function _convertDeltaToMessageChunk(
// eslint-disable-next-line @typescript-eslint/no-explicit-any
-  delta: Record<string, any>
-) {
+  delta: Record<string, any>,
+  index: number
+): {
+  message: BaseMessageChunk;
+  toolCallData?: {
+    id: string;
+    name: string;
+    index: number;
+    type: "tool_call_chunk";
+  }[];
+} {
const { role } = delta;
const content = delta.content ?? "";
let additional_kwargs;
@@ -225,13 +251,43 @@
additional_kwargs = {};
}
if (role === "user") {
-    return new HumanMessageChunk({ content });
+    return {
+      message: new HumanMessageChunk({ content }),
+    };
} else if (role === "assistant") {
-    return new AIMessageChunk({ content, additional_kwargs });
+    const toolCallChunks = _convertDeltaToolCallToToolCallChunk(
+      delta.tool_calls,
+      index
+    );
+    return {
+      message: new AIMessageChunk({
+        content,
+        additional_kwargs,
+        tool_call_chunks: toolCallChunks
+          ? toolCallChunks.map((tc) => ({
+              type: tc.type,
+              args: tc.args,
+              index: tc.index,
+            }))
+          : undefined,
+      }),
+      toolCallData: toolCallChunks
+        ? toolCallChunks.map((tc) => ({
+            id: tc.id ?? "",
+            name: tc.name ?? "",
+            index: tc.index ?? index,
+            type: "tool_call_chunk",
+          }))
+        : undefined,
+    };
} else if (role === "system") {
-    return new SystemMessageChunk({ content });
+    return {
+      message: new SystemMessageChunk({ content }),
+    };
} else {
-    return new ChatMessageChunk({ content, role });
+    return {
+      message: new ChatMessageChunk({ content, role }),
+    };
}
}
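To make the new return shape concrete, here is a rough sketch of the first streamed delta of a tool call (shape per the OpenAI-compatible streaming format Groq uses; values are illustrative) and what _convertDeltaToMessageChunk produces for it:

// Hypothetical first delta of a streamed tool call.
const delta = {
  role: "assistant",
  content: null,
  tool_calls: [
    {
      index: 0,
      id: "call_abc123",
      type: "function",
      function: { name: "get_current_weather", arguments: "" },
    },
  ],
};

// _convertDeltaToMessageChunk(delta, 0) then returns, roughly:
// {
//   message: AIMessageChunk with tool_call_chunks: [{ type: "tool_call_chunk", args: "", index: 0 }],
//   toolCallData: [{ id: "call_abc123", name: "get_current_weather", index: 0, type: "tool_call_chunk" }]
// }
// i.e. the id/name are deliberately split out into `toolCallData` so the
// caller can decide when (and how often) to emit them; later deltas
// typically carry only `function.arguments` fragments.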

@@ -322,16 +378,16 @@ export class ChatGroq extends BaseChatModel<
ls_provider: "groq",
ls_model_name: this.model,
ls_model_type: "chat",
-      ls_temperature: params.temperature,
-      ls_max_tokens: params.max_tokens,
+      ls_temperature: params.temperature ?? this.temperature,
+      ls_max_tokens: params.max_tokens ?? this.maxTokens,
ls_stop: options.stop,
};
}

async completionWithRetry(
request: ChatCompletionCreateParamsStreaming,
options?: OpenAICoreRequestOptions
-  ): Promise<AsyncIterable<ChatCompletionChunk>>;
+  ): Promise<AsyncIterable<ChatCompletionsAPI.ChatCompletionChunk>>;

async completionWithRetry(
request: ChatCompletionCreateParamsNonStreaming,
@@ -341,7 +397,9 @@
async completionWithRetry(
request: ChatCompletionCreateParams,
options?: OpenAICoreRequestOptions
-  ): Promise<AsyncIterable<ChatCompletionChunk> | ChatCompletion> {
+  ): Promise<
+    AsyncIterable<ChatCompletionsAPI.ChatCompletionChunk> | ChatCompletion
+  > {
return this.caller.call(async () =>
this.client.chat.completions.create(request, options)
);
@@ -391,76 +449,73 @@
): AsyncGenerator<ChatGenerationChunk> {
const params = this.invocationParams(options);
const messagesMapped = convertMessagesToGroqParams(messages);
-    if (options.tools !== undefined && options.tools.length > 0) {
-      const result = await this._generateNonStreaming(
-        messages,
-        options,
-        runManager
-      );
-      const generationMessage = result.generations[0].message as AIMessage;
-      if (
-        generationMessage === undefined ||
-        typeof generationMessage.content !== "string"
-      ) {
-        throw new Error("Could not parse Groq output.");
-      }
-      const toolCallChunks: ToolCallChunk[] | undefined =
-        generationMessage.tool_calls?.map((toolCall, i) => ({
-          name: toolCall.name,
-          args: JSON.stringify(toolCall.args),
-          id: toolCall.id,
-          index: i,
-          type: "tool_call_chunk",
-        }));
-      yield new ChatGenerationChunk({
-        message: new AIMessageChunk({
-          content: generationMessage.content,
-          additional_kwargs: generationMessage.additional_kwargs,
-          tool_call_chunks: toolCallChunks,
-        }),
-        text: generationMessage.content,
-      });
-    } else {
-      const response = await this.completionWithRetry(
-        {
-          ...params,
-          messages: messagesMapped,
-          stream: true,
-        },
-        {
-          signal: options?.signal,
-          headers: options?.headers,
-        }
-      );
-      let role = "";
-      for await (const data of response) {
-        const choice = data?.choices[0];
-        if (!choice) {
-          continue;
-        }
-        // The `role` field is populated in the first delta of the response
-        // but is not present in subsequent deltas. Extract it when available.
-        if (choice.delta?.role) {
-          role = choice.delta.role;
-        }
-        const chunk = new ChatGenerationChunk({
-          message: _convertDeltaToMessageChunk(
-            {
-              ...choice.delta,
-              role,
-            } ?? {}
-          ),
-          text: choice.delta.content ?? "",
-          generationInfo: {
-            finishReason: choice.finish_reason,
-          },
-        });
-        yield chunk;
-        void runManager?.handleLLMNewToken(chunk.text ?? "");
-      }
-      if (options.signal?.aborted) {
-        throw new Error("AbortError");
-      }
-    }
+    const response = await this.completionWithRetry(
+      {
+        ...params,
+        messages: messagesMapped,
+        stream: true,
+      },
+      {
+        signal: options?.signal,
+        headers: options?.headers,
+      }
+    );
+    let role = "";
+    const toolCall: {
+      id: string;
+      name: string;
+      index: number;
+      type: "tool_call_chunk";
+    }[] = [];
+    for await (const data of response) {
+      const choice = data?.choices[0];
+      if (!choice) {
+        continue;
+      }
+      // The `role` field is populated in the first delta of the response
+      // but is not present in subsequent deltas. Extract it when available.
+      if (choice.delta?.role) {
+        role = choice.delta.role;
+      }
+
+      const { message, toolCallData } = _convertDeltaToMessageChunk(
+        {
+          ...choice.delta,
+          role,
+        } ?? {},
+        choice.index
+      );
+
+      if (toolCallData) {
+        // First, ensure the ID is not already present in toolCall
+        const newToolCallData = toolCallData.filter((tc) =>
+          toolCall.every((t) => t.id !== tc.id)
+        );
+        toolCall.push(...newToolCallData);
+
+        // Yield here, ensuring the ID and name fields are only yielded once.
+        yield new ChatGenerationChunk({
+          message: new AIMessageChunk({
+            content: "",
+            tool_call_chunks: newToolCallData,
+          }),
+          text: "",
+        });
+      }
+
+      const chunk = new ChatGenerationChunk({
+        message,
+        text: choice.delta.content ?? "",
+        generationInfo: {
+          finishReason: choice.finish_reason,
+        },
+      });
+      yield chunk;
+      void runManager?.handleLLMNewToken(chunk.text ?? "");
+    }
+
+    if (options.signal?.aborted) {
+      throw new Error("AbortError");
+    }
}
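The filtering above is what the last commit-message bullet refers to: matching `tool_call_chunks` are merged by `concat` with string concatenation on each field, so re-emitting `name` or `id` on every delta would corrupt the merged call. A rough illustration of the failure mode this avoids (values illustrative):

import { AIMessageChunk } from "@langchain/core/messages";
import { concat } from "@langchain/core/utils/stream";

const first = new AIMessageChunk({
  content: "",
  tool_call_chunks: [
    { name: "get_current_weather", id: "call_1", args: '{"loc', index: 0, type: "tool_call_chunk" },
  ],
});
const second = new AIMessageChunk({
  content: "",
  // name/id intentionally absent: repeating them here would merge into
  // "get_current_weatherget_current_weather" and "call_1call_1".
  tool_call_chunks: [{ args: 'ation": "SF"}', index: 0, type: "tool_call_chunk" }],
});

const merged = concat(first, second);
// merged.tool_calls ≈ [{ name: "get_current_weather", id: "call_1", args: { location: "SF" } }]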

@@ -518,7 +573,7 @@ export class ChatGroq extends BaseChatModel<
completion_tokens: completionTokens,
prompt_tokens: promptTokens,
total_tokens: totalTokens,
-    } = data.usage as ChatCompletion.Usage;
+    } = data.usage as CompletionsAPI.CompletionUsage;

if (completionTokens) {
tokenUsage.completionTokens =
48 changes: 47 additions & 1 deletion libs/langchain-groq/src/tests/chat_models.int.test.ts
@@ -1,5 +1,13 @@
import { test } from "@jest/globals";
-import { AIMessage, HumanMessage, ToolMessage } from "@langchain/core/messages";
+import {
+  AIMessage,
+  AIMessageChunk,
+  HumanMessage,
+  ToolMessage,
+} from "@langchain/core/messages";
+import { tool } from "@langchain/core/tools";
import { z } from "zod";
+import { concat } from "@langchain/core/utils/stream";
import { ChatGroq } from "../chat_models.js";

test("invoke", async () => {
@@ -197,3 +205,41 @@ test("Few shotting with tool calls", async () => {
// console.log(res);
expect(res.content).toContain("24");
});

test("Groq can stream tool calls", async () => {
const model = new ChatGroq({
model: "llama-3.1-70b-versatile",
temperature: 0,
});

const weatherTool = tool((_) => "The temperature is 24 degrees with hail.", {
name: "get_current_weather",
schema: z.object({
location: z
.string()
.describe("The location to get the current weather for."),
}),
description: "Get the current weather in a given location.",
});

const modelWithTools = model.bindTools([weatherTool]);

const stream = await modelWithTools.stream(
"What is the weather in San Francisco?"
);

let finalMessage: AIMessageChunk | undefined;
for await (const chunk of stream) {
finalMessage = !finalMessage ? chunk : concat(finalMessage, chunk);
}

expect(finalMessage).toBeDefined();
if (!finalMessage) return;

expect(finalMessage.tool_calls?.[0]).toBeDefined();
if (!finalMessage.tool_calls?.[0]) return;

expect(finalMessage.tool_calls?.[0].name).toBe("get_current_weather");
expect(finalMessage.tool_calls?.[0].args).toHaveProperty("location");
expect(finalMessage.tool_calls?.[0].id).toBeDefined();
});
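Once the stream has been aggregated, the recovered tool call can be executed directly. A possible follow-on to the test above — assuming a @langchain/core version in the declared range, where invoking a tool with a full ToolCall returns a ToolMessage:

// Hypothetical continuation of the test above.
const toolCall = finalMessage.tool_calls?.[0];
if (toolCall) {
  // The returned ToolMessage carries a tool_call_id matching the streamed id.
  const toolMessage = await weatherTool.invoke(toolCall);
  console.log(toolMessage.content); // "The temperature is 24 degrees with hail."
}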
libs/langchain-groq/src/tests/chat_models.standard.int.test.ts
@@ -19,7 +19,7 @@ class ChatGroqStandardIntegrationTests extends ChatModelIntegrationTests<
chatModelHasToolCalling: true,
chatModelHasStructuredOutput: true,
constructorArgs: {
model: "mixtral-8x7b-32768",
model: "llama-3.1-70b-versatile",
},
});
}
