Skip to content

Commit

Permalink
Add AnthropicFunctionsBedrock (langchain-ai#2892)
Browse files Browse the repository at this point in the history
* x

* wip

* Composition over inheritance

* Add test

* Revert Bedrock change

* Fix lint

* Further reduce changes

---------

Co-authored-by: jacoblee93 <[email protected]>
  • Loading branch information
efriis and jacoblee93 authored Oct 12, 2023
1 parent 20ceea0 commit d7dded6
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 5 deletions.
41 changes: 37 additions & 4 deletions langchain/src/experimental/chat_models/anthropic_functions.ts
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import { XMLParser } from "fast-xml-parser";

import { BaseChatModelParams } from "../../chat_models/base.js";
import { BaseChatModel, BaseChatModelParams } from "../../chat_models/base.js";
import { CallbackManagerForLLMRun } from "../../callbacks/manager.js";
import {
AIMessage,
BaseMessage,
ChatGenerationChunk,
ChatResult,
SystemMessage,
} from "../../schema/index.js";
Expand Down Expand Up @@ -44,13 +45,44 @@ export interface ChatAnthropicFunctionsCallOptions
tools?: StructuredTool[];
}

export class AnthropicFunctions extends ChatAnthropic<ChatAnthropicFunctionsCallOptions> {
export type AnthropicFunctionsInput = Partial<AnthropicInput> &
BaseChatModelParams & {
llm?: BaseChatModel;
};

export class AnthropicFunctions extends BaseChatModel<ChatAnthropicFunctionsCallOptions> {
llm: BaseChatModel;

stopSequences?: string[];

lc_namespace = ["langchain", "experimental", "chat_models"];

static lc_name(): string {
return "AnthropicFunctions";
}

constructor(fields?: Partial<AnthropicInput> & BaseChatModelParams) {
constructor(fields?: AnthropicFunctionsInput) {
super(fields ?? {});
this.llm = fields?.llm ?? new ChatAnthropic(fields);
this.stopSequences =
fields?.stopSequences ?? (this.llm as ChatAnthropic).stopSequences;
}

invocationParams() {
return this.llm.invocationParams();
}

/** @ignore */
_identifyingParams() {
return this.llm._identifyingParams();
}

async *_streamResponseChunks(
messages: BaseMessage[],
options: this["ParsedCallOptions"],
runManager?: CallbackManagerForLLMRun
): AsyncGenerator<ChatGenerationChunk> {
yield* this.llm._streamResponseChunks(messages, options, runManager);
}

async _generate(
Expand Down Expand Up @@ -109,12 +141,13 @@ export class AnthropicFunctions extends ChatAnthropic<ChatAnthropicFunctionsCall
`If "function_call" is provided, "functions" must also be.`
);
}
const chatResult = await super._generate(
const chatResult = await this.llm._generate(
promptMessages,
options,
runManager
);
const chatGenerationContent = chatResult.generations[0].message.content;

if (forced) {
const parser = new XMLParser();
const result = parser.parse(`${chatGenerationContent}</tool_input>`);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
/* eslint-disable no-process-env */
/* eslint-disable @typescript-eslint/no-non-null-assertion */
import { test } from "@jest/globals";
import { HumanMessage } from "../../../schema/index.js";
import { BaseMessageChunk, HumanMessage } from "../../../schema/index.js";
import { AnthropicFunctions } from "../anthropic_functions.js";
import { ChatBedrock } from "../../../chat_models/bedrock.js";

test("Test AnthropicFunctions", async () => {
const chat = new AnthropicFunctions({ modelName: "claude-2" });
Expand All @@ -10,6 +12,18 @@ test("Test AnthropicFunctions", async () => {
console.log(JSON.stringify(res));
});

test("Test AnthropicFunctions streaming", async () => {
const chat = new AnthropicFunctions({ modelName: "claude-2" });
const message = new HumanMessage("Hello!");
const stream = await chat.stream([message]);
const chunks: BaseMessageChunk[] = [];
for await (const chunk of stream) {
console.log(chunk);
chunks.push(chunk);
}
expect(chunks.length).toBeGreaterThan(1);
});

test("Test AnthropicFunctions with functions", async () => {
const chat = new AnthropicFunctions({
modelName: "claude-2",
Expand Down Expand Up @@ -78,3 +92,48 @@ test("Test AnthropicFunctions with a forced function call", async () => {
const res = await chat.invoke([message]);
console.log(JSON.stringify(res));
});

test("Test AnthropicFunctions with a Bedrock model", async () => {
const chatBedrock = new ChatBedrock({
region: process.env.BEDROCK_AWS_REGION ?? "us-east-1",
model: "anthropic.claude-v2",
temperature: 0.1,
credentials: {
secretAccessKey: process.env.BEDROCK_AWS_SECRET_ACCESS_KEY!,
accessKeyId: process.env.BEDROCK_AWS_ACCESS_KEY_ID!,
},
});
const model = new AnthropicFunctions({
llm: chatBedrock,
}).bind({
functions: [
{
name: "get_current_weather",
description: "Get the current weather in a given location",
parameters: {
type: "object",
properties: {
location: {
type: "string",
description: "The city and state, e.g. San Francisco, CA",
},
unit: { type: "string", enum: ["celsius", "fahrenheit"] },
},
required: ["location"],
},
},
],
// You can set the `function_call` arg to force the model to use a function
function_call: {
name: "get_current_weather",
},
});

const response = await model.invoke([
new HumanMessage({
content: "What's the weather in Boston?",
}),
]);

console.log(response);
});

0 comments on commit d7dded6

Please sign in to comment.