Skip to content

Commit

Permalink
mistralai[patch]: Translate tool call ids to mistral compat format (l…
Browse files Browse the repository at this point in the history
…angchain-ai#6217)

* mistralai[patch]: Translate tool call ids to mistral compat format

* chore: lint files

* chore: lint files

* nits

* properly handle

* Update libs/langchain-mistralai/src/chat_models.ts
  • Loading branch information
bracesproul authored Jul 25, 2024
1 parent 9ac2f20 commit 00641e7
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 3 deletions.
19 changes: 17 additions & 2 deletions libs/langchain-mistralai/src/chat_models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ import {
} from "@langchain/core/runnables";
import { zodToJsonSchema } from "zod-to-json-schema";
import { ToolCallChunk } from "@langchain/core/messages/tool";
import { _convertToolCallIdToMistralCompatible } from "./utils.js";

interface TokenUsage {
completionTokens?: number;
Expand Down Expand Up @@ -199,7 +200,10 @@ function convertMessagesToMistralMessages(
const getTools = (message: BaseMessage): MistralAIToolCalls[] | undefined => {
if (isAIMessage(message) && !!message.tool_calls?.length) {
return message.tool_calls
.map((toolCall) => ({ ...toolCall, id: toolCall.id }))
.map((toolCall) => ({
...toolCall,
id: _convertToolCallIdToMistralCompatible(toolCall.id ?? ""),
}))
.map(convertLangChainToolCallToOpenAI) as MistralAIToolCalls[];
}
if (!message.additional_kwargs.tool_calls?.length) {
Expand All @@ -208,7 +212,7 @@ function convertMessagesToMistralMessages(
const toolCalls: Omit<OpenAIToolCall, "index">[] =
message.additional_kwargs.tool_calls;
return toolCalls?.map((toolCall) => ({
id: toolCall.id,
id: _convertToolCallIdToMistralCompatible(toolCall.id),
type: "function",
function: toolCall.function,
}));
Expand All @@ -217,6 +221,17 @@ function convertMessagesToMistralMessages(
return messages.map((message) => {
const toolCalls = getTools(message);
const content = toolCalls === undefined ? getContent(message.content) : "";
if ("tool_call_id" in message && typeof message.tool_call_id === "string") {
return {
role: getRole(message._getType()),
content,
name: message.name,
tool_call_id: _convertToolCallIdToMistralCompatible(
message.tool_call_id
),
};
}

return {
role: getRole(message._getType()),
content,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class ChatMistralAIStandardUnitTests extends ChatModelUnitTests<

expectedLsParams(): Partial<LangSmithParams> {
console.warn(
"Overriding testStandardParams. ChatCloudflareWorkersAI does not support stop sequences."
"Overriding testStandardParams. ChatMistralAI does not support stop sequences."
);
return {
ls_provider: "string",
Expand Down
29 changes: 29 additions & 0 deletions libs/langchain-mistralai/src/tests/chat_models.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import {
_isValidMistralToolCallId,
_convertToolCallIdToMistralCompatible,
} from "../utils.js";

describe("Mistral Tool Call ID Conversion", () => {
test("valid and invalid Mistral tool call IDs", () => {
expect(_isValidMistralToolCallId("ssAbar4Dr")).toBe(true);
expect(_isValidMistralToolCallId("abc123")).toBe(false);
expect(_isValidMistralToolCallId("call_JIIjI55tTipFFzpcP8re3BpM")).toBe(
false
);
});

test("tool call ID conversion", () => {
const resultMap: Record<string, string> = {
ssAbar4Dr: "ssAbar4Dr",
abc123: "0001yoN1K",
call_JIIjI55tTipFFzpcP8re3BpM: "0001sqrj5",
12345: "00003akVR",
};

for (const [inputId, expectedOutput] of Object.entries(resultMap)) {
const convertedId = _convertToolCallIdToMistralCompatible(inputId);
expect(convertedId).toBe(expectedOutput);
expect(_isValidMistralToolCallId(convertedId)).toBe(true);
}
});
});
46 changes: 46 additions & 0 deletions libs/langchain-mistralai/src/utils.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
// Mistral enforces a specific pattern for tool call IDs
const TOOL_CALL_ID_PATTERN = /^[a-zA-Z0-9]{9}$/;

export function _isValidMistralToolCallId(toolCallId: string): boolean {
return TOOL_CALL_ID_PATTERN.test(toolCallId);
}

function _base62Encode(num: number): string {
let numCopy = num;
const base62 =
"0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ";
if (numCopy === 0) return base62[0];
const arr: string[] = [];
const base = base62.length;
while (numCopy) {
arr.push(base62[numCopy % base]);
numCopy = Math.floor(numCopy / base);
}
return arr.reverse().join("");
}

function _simpleHash(str: string): number {
let hash = 0;
for (let i = 0; i < str.length; i += 1) {
const char = str.charCodeAt(i);
hash = (hash << 5) - hash + char;
hash &= hash; // Convert to 32-bit integer
}
return Math.abs(hash);
}

export function _convertToolCallIdToMistralCompatible(
toolCallId: string
): string {
if (_isValidMistralToolCallId(toolCallId)) {
return toolCallId;
} else {
const hash = _simpleHash(toolCallId);
const base62Str = _base62Encode(hash);
if (base62Str.length >= 9) {
return base62Str.slice(0, 9);
} else {
return base62Str.padStart(9, "0");
}
}
}

0 comments on commit 00641e7

Please sign in to comment.