Commit

google-common[minor]: Fix streaming tool calls (langchain-ai#6204)
* google-common[minor]: Fix streaming tool calls

* lint format

* chore: lint files
bracesproul authored Jul 25, 2024
1 parent 37321ef commit 0808ef6
Showing 5 changed files with 114 additions and 98 deletions.
22 changes: 16 additions & 6 deletions libs/langchain-google-common/src/utils/gemini.ts
@@ -2,7 +2,7 @@ import { v4 as uuidv4 } from "uuid";
 import {
   AIMessage,
   AIMessageChunk,
-  AIMessageFields,
+  AIMessageChunkFields,
   BaseMessage,
   BaseMessageChunk,
   BaseMessageFields,
@@ -566,7 +566,7 @@ export function chunkToString(chunk: BaseMessageChunk): string {
 }
 
 export function partToMessageChunk(part: GeminiPart): BaseMessageChunk {
-  const fields = partsToBaseMessageFields([part]);
+  const fields = partsToBaseMessageChunkFields([part]);
   if (typeof fields.content === "string") {
     return new AIMessageChunk(fields);
   } else if (fields.content.every((item) => item.type === "text")) {
@@ -636,12 +636,15 @@ export function responseToBaseMessageFields(
   response: GoogleLLMResponse
 ): BaseMessageFields {
   const parts = responseToParts(response);
-  return partsToBaseMessageFields(parts);
+  return partsToBaseMessageChunkFields(parts);
 }
 
-export function partsToBaseMessageFields(parts: GeminiPart[]): AIMessageFields {
-  const fields: AIMessageFields = {
+export function partsToBaseMessageChunkFields(
+  parts: GeminiPart[]
+): AIMessageChunkFields {
+  const fields: AIMessageChunkFields = {
     content: partsToMessageContent(parts),
+    tool_call_chunks: [],
     tool_calls: [],
     invalid_tool_calls: [],
   };
@@ -650,6 +653,13 @@ export function partsToBaseMessageFields(parts: GeminiPart[]): AIMessageFields {
   if (rawTools.length > 0) {
     const tools = toolsRawToTools(rawTools);
     for (const tool of tools) {
+      fields.tool_call_chunks?.push({
+        name: tool.function.name,
+        args: tool.function.arguments,
+        id: tool.id,
+        type: "tool_call_chunk",
+      });
+
       try {
         fields.tool_calls?.push({
           name: tool.function.name,
@@ -661,7 +671,7 @@ export function partsToBaseMessageFields(parts: GeminiPart[]): AIMessageFields {
       } catch (e: any) {
         fields.invalid_tool_calls?.push({
           name: tool.function.name,
-          args: JSON.parse(tool.function.arguments),
+          args: tool.function.arguments,
           id: tool.id,
           error: e.message,
           type: "invalid_tool_call",
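The heart of the fix is that each streamed Gemini part now yields tool_call_chunks whose args stay a raw JSON string, so partial AIMessageChunks can be merged while streaming and only the merged result is parsed. A minimal sketch of that merge behavior, assuming @langchain/core's AIMessageChunk and concat helpers; the chunk values, index, and id below are illustrative only, not produced by this commit's code:

import { AIMessageChunk } from "@langchain/core/messages";
import { concat } from "@langchain/core/utils/stream";

// Two partial chunks, roughly as a streaming response might surface them.
// In a tool_call_chunk, `args` is a raw JSON string fragment, not a parsed object.
const first = new AIMessageChunk({
  content: "",
  tool_call_chunks: [
    {
      name: "current_weather_tool",
      args: '{"location":',
      id: "call-1", // illustrative id
      index: 0,
      type: "tool_call_chunk",
    },
  ],
});
const second = new AIMessageChunk({
  content: "",
  tool_call_chunks: [
    { args: '"San Francisco"}', index: 0, type: "tool_call_chunk" },
  ],
});

// Merging chunk by chunk, as a streaming loop would, concatenates the arg
// fragments; the merged chunk then exposes a parsed entry in tool_calls.
const merged = concat(first, second);
console.log(merged.tool_calls);
// roughly: [{ name: "current_weather_tool", args: { location: "San Francisco" }, id: "call-1", type: "tool_call" }]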
3 changes: 2 additions & 1 deletion libs/langchain-google-vertexai/package.json
@@ -70,7 +70,8 @@
     "release-it": "^15.10.1",
     "rollup": "^4.5.2",
     "ts-jest": "^29.1.0",
-    "typescript": "<5.2.0"
+    "typescript": "<5.2.0",
+    "zod": "^3.22.4"
   },
   "publishConfig": {
     "access": "public"
177 changes: 86 additions & 91 deletions libs/langchain-google-vertexai/src/tests/chat_models.int.test.ts
@@ -11,111 +11,70 @@ import {
   SystemMessage,
   ToolMessage,
 } from "@langchain/core/messages";
-import { ChatVertexAI } from "../chat_models.js";
+import { tool } from "@langchain/core/tools";
+import { concat } from "@langchain/core/utils/stream";
+import { z } from "zod";
 import { GeminiTool } from "../types.js";
+import { ChatVertexAI } from "../chat_models.js";
 
 describe("GAuth Chat", () => {
   test("invoke", async () => {
     const model = new ChatVertexAI();
-    try {
-      const res = await model.invoke("What is 1 + 1?");
-      expect(res).toBeDefined();
-      expect(res._getType()).toEqual("ai");
-
-      const aiMessage = res as AIMessageChunk;
-      expect(aiMessage.content).toBeDefined();
-
-      expect(typeof aiMessage.content).toBe("string");
-      const text = aiMessage.content as string;
-      expect(text).toMatch(/(1 + 1 (equals|is|=) )?2.? ?/);
+    const res = await model.invoke("What is 1 + 1?");
+    expect(res).toBeDefined();
+    expect(res._getType()).toEqual("ai");
 
-      /*
-      expect(aiMessage.content.length).toBeGreaterThan(0);
-      expect(aiMessage.content[0]).toBeDefined();
-      const content = aiMessage.content[0] as MessageContentComplex;
-      expect(content).toHaveProperty("type");
-      expect(content.type).toEqual("text");
+    const aiMessage = res as AIMessageChunk;
+    expect(aiMessage.content).toBeDefined();
 
-      const textContent = content as MessageContentText;
-      expect(textContent.text).toBeDefined();
-      expect(textContent.text).toEqual("2");
-      */
-    } catch (e) {
-      console.error(e);
-      throw e;
-    }
+    expect(typeof aiMessage.content).toBe("string");
+    const text = aiMessage.content as string;
+    expect(text).toMatch(/(1 + 1 (equals|is|=) )?2.? ?/);
   });
 
   test("generate", async () => {
     const model = new ChatVertexAI();
-    try {
-      const messages: BaseMessage[] = [
-        new SystemMessage(
-          "You will reply to all requests to flip a coin with either H, indicating heads, or T, indicating tails."
-        ),
-        new HumanMessage("Flip it"),
-        new AIMessage("T"),
-        new HumanMessage("Flip the coin again"),
-      ];
-      const res = await model.predictMessages(messages);
-      expect(res).toBeDefined();
-      expect(res._getType()).toEqual("ai");
-
-      const aiMessage = res as AIMessageChunk;
-      expect(aiMessage.content).toBeDefined();
-
-      expect(typeof aiMessage.content).toBe("string");
-      const text = aiMessage.content as string;
-      expect(["H", "T"]).toContainEqual(text);
-
-      /*
-      expect(aiMessage.content.length).toBeGreaterThan(0);
-      expect(aiMessage.content[0]).toBeDefined();
+    const messages: BaseMessage[] = [
+      new SystemMessage(
+        "You will reply to all requests to flip a coin with either H, indicating heads, or T, indicating tails."
+      ),
+      new HumanMessage("Flip it"),
+      new AIMessage("T"),
+      new HumanMessage("Flip the coin again"),
+    ];
+    const res = await model.predictMessages(messages);
+    expect(res).toBeDefined();
+    expect(res._getType()).toEqual("ai");
 
-      const content = aiMessage.content[0] as MessageContentComplex;
-      expect(content).toHaveProperty("type");
-      expect(content.type).toEqual("text");
+    const aiMessage = res as AIMessageChunk;
+    expect(aiMessage.content).toBeDefined();
 
-      const textContent = content as MessageContentText;
-      expect(textContent.text).toBeDefined();
-      expect(["H", "T"]).toContainEqual(textContent.text);
-      */
-    } catch (e) {
-      console.error(e);
-      throw e;
-    }
+    expect(typeof aiMessage.content).toBe("string");
+    const text = aiMessage.content as string;
+    expect(["H", "T"]).toContainEqual(text);
   });
 
   test("stream", async () => {
     const model = new ChatVertexAI();
-    try {
-      const input: BaseLanguageModelInput = new ChatPromptValue([
-        new SystemMessage(
-          "You will reply to all requests to flip a coin with either H, indicating heads, or T, indicating tails."
-        ),
-        new HumanMessage("Flip it"),
-        new AIMessage("T"),
-        new HumanMessage("Flip the coin again"),
-      ]);
-      const res = await model.stream(input);
-      const resArray: BaseMessageChunk[] = [];
-      for await (const chunk of res) {
-        resArray.push(chunk);
-      }
-      expect(resArray).toBeDefined();
-      expect(resArray.length).toBeGreaterThanOrEqual(1);
-
-      const lastChunk = resArray[resArray.length - 1];
-      expect(lastChunk).toBeDefined();
-      expect(lastChunk._getType()).toEqual("ai");
-      const aiChunk = lastChunk as AIMessageChunk;
-      console.log(aiChunk);
-
-      console.log(JSON.stringify(resArray, null, 2));
-    } catch (e) {
-      console.error(e);
-      throw e;
+    const input: BaseLanguageModelInput = new ChatPromptValue([
+      new SystemMessage(
+        "You will reply to all requests to flip a coin with either H, indicating heads, or T, indicating tails."
+      ),
+      new HumanMessage("Flip it"),
+      new AIMessage("T"),
+      new HumanMessage("Flip the coin again"),
+    ]);
+    const res = await model.stream(input);
+    const resArray: BaseMessageChunk[] = [];
+    for await (const chunk of res) {
+      resArray.push(chunk);
     }
+    expect(resArray).toBeDefined();
+    expect(resArray.length).toBeGreaterThanOrEqual(1);
+
+    const lastChunk = resArray[resArray.length - 1];
+    expect(lastChunk).toBeDefined();
+    expect(lastChunk._getType()).toEqual("ai");
   });
 
   test("function", async () => {
@@ -209,7 +168,7 @@ describe("GAuth Chat", () => {
     for await (const chunk of res) {
       resArray.push(chunk);
     }
-    console.log(JSON.stringify(resArray, null, 2));
+    // console.log(JSON.stringify(resArray, null, 2));
   });
 
   test("withStructuredOutput", async () => {
@@ -249,7 +208,7 @@ test("Stream token count usage_metadata", async () => {
       res = res.concat(chunk);
     }
   }
-  console.log(res);
+  // console.log(res);
   expect(res?.usage_metadata).toBeDefined();
   if (!res?.usage_metadata) {
     return;
@@ -276,7 +235,7 @@ test("streamUsage excludes token usage", async () => {
       res = res.concat(chunk);
     }
   }
-  console.log(res);
+  // console.log(res);
  expect(res?.usage_metadata).not.toBeDefined();
 });
 
@@ -286,7 +245,7 @@ test("Invoke token count usage_metadata", async () => {
     maxOutputTokens: 10,
   });
   const res = await model.invoke("Why is the sky blue? Be concise.");
-  console.log(res);
+  // console.log(res);
   expect(res?.usage_metadata).toBeDefined();
   if (!res?.usage_metadata) {
     return;
@@ -322,3 +281,39 @@ test("Streaming true constructor param will stream", async () => {
 
   expect(totalTokenCount).toBeGreaterThan(1);
 });
+
+test("ChatGoogleGenerativeAI can stream tools", async () => {
+  const model = new ChatVertexAI({});
+
+  const weatherTool = tool(
+    (_) => "The weather in San Francisco today is 18 degrees and sunny.",
+    {
+      name: "current_weather_tool",
+      description: "Get the current weather for a given location.",
+      schema: z.object({
+        location: z.string().describe("The location to get the weather for."),
+      }),
+    }
+  );
+
+  const modelWithTools = model.bindTools([weatherTool]);
+  const stream = await modelWithTools.stream(
+    "Whats the weather like today in San Francisco?"
+  );
+  let finalChunk: AIMessageChunk | undefined;
+  for await (const chunk of stream) {
+    finalChunk = !finalChunk ? chunk : concat(finalChunk, chunk);
+  }
+
+  expect(finalChunk).toBeDefined();
+  if (!finalChunk) return;
+
+  const toolCalls = finalChunk.tool_calls;
+  expect(toolCalls).toBeDefined();
+  if (!toolCalls) {
+    throw new Error("tool_calls not in response");
+  }
+  expect(toolCalls.length).toBe(1);
+  expect(toolCalls[0].name).toBe("current_weather_tool");
+  expect(toolCalls[0].args).toHaveProperty("location");
+});
@@ -19,6 +19,7 @@ class ChatVertexAIStandardIntegrationTests extends ChatModelIntegrationTests<
       Cls: ChatVertexAI,
       chatModelHasToolCalling: true,
       chatModelHasStructuredOutput: true,
+      invokeResponseType: AIMessageChunk,
       constructorArgs: {
         model: "gemini-1.5-pro",
       },
@@ -32,6 +33,14 @@
       "Not implemented."
     );
   }
+
+  async testInvokeMoreComplexTools() {
+    this.skipTestMessage(
+      "testInvokeMoreComplexTools",
+      "ChatVertexAI",
+      "Google VertexAI does not support tool schemas where the object properties are not defined."
+    );
+  }
 }
 
 const testClass = new ChatVertexAIStandardIntegrationTests();
1 change: 1 addition & 0 deletions yarn.lock
@@ -11695,6 +11695,7 @@ __metadata:
     rollup: ^4.5.2
     ts-jest: ^29.1.0
     typescript: <5.2.0
+    zod: ^3.22.4
   languageName: unknown
   linkType: soft
 
