Use invoke method for non-streaming calls (langchain-ai#2787)
jacoblee93 authored Oct 4, 2023
1 parent 88e503e commit ffcabb1
Showing 3 changed files with 144 additions and 52 deletions.
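In short: `_call` no longer drains the streaming generator to build its result. Non-streaming calls now send a single signed request to Bedrock's `invoke` endpoint, while streaming continues to use `invoke-with-response-stream` (for providers that support it, e.g. Anthropic). A minimal usage sketch of the two paths — the model id, region, and credential setup below are illustrative placeholders, not part of this commit:

import { Bedrock } from "langchain/llms/bedrock";

const llm = new Bedrock({
  model: "anthropic.claude-v2", // example model id
  region: "us-east-1", // example region; credentials resolve from the environment
  maxTokens: 20,
});

// Non-streaming: after this commit, one signed POST to .../invoke
const text = await llm.call("Human: What is your name?\n\nAssistant:");

// Streaming: still a signed POST to .../invoke-with-response-stream
const stream = await llm.stream("Human: What is your name?\n\nAssistant:");
for await (const chunk of stream) {
  console.log(chunk);
}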
73 changes: 47 additions & 26 deletions langchain/src/chat_models/bedrock.ts
@@ -158,28 +158,32 @@ export class ChatBedrock extends SimpleChatModel implements BaseBedrockInput {
    */
   async _call(
     messages: BaseMessage[],
-    options: this["ParsedCallOptions"],
-    runManager?: CallbackManagerForLLMRun
+    options: this["ParsedCallOptions"]
   ): Promise<string> {
-    const chunks = [];
-    for await (const chunk of this._streamResponseChunks(
-      messages,
-      options,
-      runManager
-    )) {
-      chunks.push(chunk);
-    }
-    return chunks.map((chunk) => chunk.text).join("");
+    const service = "bedrock-runtime";
+    const endpointHost =
+      this.endpointHost ?? `${service}.${this.region}.amazonaws.com`;
+    const provider = this.model.split(".")[0];
+    const response = await this._signedFetch(messages, options, {
+      bedrockMethod: "invoke",
+      endpointHost,
+      provider,
+    });
+    const json = await response.json();
+    const text = BedrockLLMInputOutputAdapter.prepareOutput(provider, json);
+    return text;
   }
 
-  async *_streamResponseChunks(
+  async _signedFetch(
     messages: BaseMessage[],
     options: this["ParsedCallOptions"],
-    runManager?: CallbackManagerForLLMRun
-  ): AsyncGenerator<ChatGenerationChunk> {
-    const provider = this.model.split(".")[0];
-    const service = "bedrock-runtime";
-
+    fields: {
+      bedrockMethod: "invoke" | "invoke-with-response-stream";
+      endpointHost: string;
+      provider: string;
+    }
+  ) {
+    const { bedrockMethod, endpointHost, provider } = fields;
     const inputBody = BedrockLLMInputOutputAdapter.prepareInput(
       provider,
       convertMessagesToPromptAnthropic(messages),
@@ -189,13 +193,8 @@ export class ChatBedrock extends SimpleChatModel implements BaseBedrockInput {
       this.modelKwargs
     );
 
-    const endpointHost =
-      this.endpointHost ?? `${service}.${this.region}.amazonaws.com`;
-
-    const amazonMethod =
-      provider === "anthropic" ? "invoke-with-response-stream" : "invoke";
     const url = new URL(
-      `https://${endpointHost}/model/${this.model}/${amazonMethod}`
+      `https://${endpointHost}/model/${this.model}/${bedrockMethod}`
     );
 
     const request = new HttpRequest({
@@ -232,12 +231,34 @@ export class ChatBedrock extends SimpleChatModel implements BaseBedrockInput {
         method: signedRequest.method,
       })
     );
+    return response;
+  }
+
+  async *_streamResponseChunks(
+    messages: BaseMessage[],
+    options: this["ParsedCallOptions"],
+    runManager?: CallbackManagerForLLMRun
+  ): AsyncGenerator<ChatGenerationChunk> {
+    const provider = this.model.split(".")[0];
+    const service = "bedrock-runtime";
+
+    const endpointHost =
+      this.endpointHost ?? `${service}.${this.region}.amazonaws.com`;
+
+    const bedrockMethod =
+      provider === "anthropic" ? "invoke-with-response-stream" : "invoke";
+
+    const response = await this._signedFetch(messages, options, {
+      bedrockMethod,
+      endpointHost,
+      provider,
+    });
 
     if (response.status < 200 || response.status >= 300) {
       throw Error(
-        `Failed to access underlying url '${url}': got ${response.status} ${
-          response.statusText
-        }: ${await response.text()}`
+        `Failed to access underlying url '${endpointHost}': got ${
+          response.status
+        } ${response.statusText}: ${await response.text()}`
       );
     }
 
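The extracted `_signedFetch` helper is what lets `_call` and `_streamResponseChunks` share the SigV4 signing and raw fetch logic, differing only in which `bedrockMethod` they pass. As a rough illustration of the flow it wraps — `signedBedrockFetch` is a hypothetical standalone function, and the real method also threads in the prepared input body, `this.credentials`, and fetch overrides — it boils down to something like:

import { HttpRequest } from "@aws-sdk/protocol-http";
import { SignatureV4 } from "@aws-sdk/signature-v4";
import { Sha256 } from "@aws-crypto/sha256-js";

// Hypothetical sketch of the signing flow inside _signedFetch.
async function signedBedrockFetch(
  host: string,
  model: string,
  bedrockMethod: "invoke" | "invoke-with-response-stream",
  body: string,
  region: string,
  credentials: { accessKeyId: string; secretAccessKey: string }
) {
  const request = new HttpRequest({
    protocol: "https:",
    hostname: host,
    path: `/model/${model}/${bedrockMethod}`,
    method: "POST",
    headers: {
      host, // SigV4 requires the host header to be part of the signature
      accept: "application/json",
      "content-type": "application/json",
    },
    body,
  });

  const signer = new SignatureV4({
    credentials,
    service: "bedrock-runtime",
    region,
    sha256: Sha256,
  });
  const signedRequest = await signer.sign(request);

  // Replay the signed request with the global fetch API.
  return fetch(`https://${host}/model/${model}/${bedrockMethod}`, {
    method: signedRequest.method,
    headers: signedRequest.headers,
    body: signedRequest.body,
  });
}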
73 changes: 47 additions & 26 deletions langchain/src/llms/bedrock.ts
@@ -96,28 +96,32 @@ export class Bedrock extends LLM implements BaseBedrockInput {
    */
   async _call(
     prompt: string,
-    options: this["ParsedCallOptions"],
-    runManager?: CallbackManagerForLLMRun
+    options: this["ParsedCallOptions"]
   ): Promise<string> {
-    const chunks = [];
-    for await (const chunk of this._streamResponseChunks(
-      prompt,
-      options,
-      runManager
-    )) {
-      chunks.push(chunk);
-    }
-    return chunks.map((chunk) => chunk.text).join("");
+    const service = "bedrock-runtime";
+    const endpointHost =
+      this.endpointHost ?? `${service}.${this.region}.amazonaws.com`;
+    const provider = this.model.split(".")[0];
+    const response = await this._signedFetch(prompt, options, {
+      bedrockMethod: "invoke",
+      endpointHost,
+      provider,
+    });
+    const json = await response.json();
+    const text = BedrockLLMInputOutputAdapter.prepareOutput(provider, json);
+    return text;
   }
 
-  async *_streamResponseChunks(
+  async _signedFetch(
     prompt: string,
     options: this["ParsedCallOptions"],
-    runManager?: CallbackManagerForLLMRun
-  ): AsyncGenerator<GenerationChunk> {
-    const provider = this.model.split(".")[0];
-    const service = "bedrock-runtime";
-
+    fields: {
+      bedrockMethod: "invoke" | "invoke-with-response-stream";
+      endpointHost: string;
+      provider: string;
+    }
+  ) {
+    const { bedrockMethod, endpointHost, provider } = fields;
     const inputBody = BedrockLLMInputOutputAdapter.prepareInput(
       provider,
       prompt,
@@ -127,13 +131,8 @@ export class Bedrock extends LLM implements BaseBedrockInput {
       this.modelKwargs
     );
 
-    const endpointHost =
-      this.endpointHost ?? `${service}.${this.region}.amazonaws.com`;
-
-    const amazonMethod =
-      provider === "anthropic" ? "invoke-with-response-stream" : "invoke";
     const url = new URL(
-      `https://${endpointHost}/model/${this.model}/${amazonMethod}`
+      `https://${endpointHost}/model/${this.model}/${bedrockMethod}`
     );
 
     const request = new HttpRequest({
@@ -170,12 +169,34 @@ export class Bedrock extends LLM implements BaseBedrockInput {
         method: signedRequest.method,
       })
     );
+    return response;
+  }
+
+  async *_streamResponseChunks(
+    prompt: string,
+    options: this["ParsedCallOptions"],
+    runManager?: CallbackManagerForLLMRun
+  ): AsyncGenerator<GenerationChunk> {
+    const provider = this.model.split(".")[0];
+    const bedrockMethod =
+      provider === "anthropic" ? "invoke-with-response-stream" : "invoke";
+
+    const service = "bedrock-runtime";
+    const endpointHost =
+      this.endpointHost ?? `${service}.${this.region}.amazonaws.com`;
+
+    // Send request to AWS using the low-level fetch API
+    const response = await this._signedFetch(prompt, options, {
+      bedrockMethod,
+      endpointHost,
+      provider,
+    });
 
     if (response.status < 200 || response.status >= 300) {
       throw Error(
-        `Failed to access underlying url '${url}': got ${response.status} ${
-          response.statusText
-        }: ${await response.text()}`
+        `Failed to access underlying url '${endpointHost}': got ${
+          response.status
+        } ${response.statusText}: ${await response.text()}`
      );
    }
 
50 changes: 50 additions & 0 deletions langchain/src/llms/tests/bedrock.int.test.ts
@@ -46,3 +46,53 @@ test("Test Bedrock LLM: Claude-v2", async () => {
   expect(typeof res).toBe("string");
   console.log(res);
 });
+
+test("Test Bedrock LLM streaming: AI21", async () => {
+  const region = process.env.BEDROCK_AWS_REGION!;
+  const model = "ai21.j2-grande-instruct";
+  const prompt = "Human: What is your name?";
+
+  const bedrock = new Bedrock({
+    maxTokens: 20,
+    region,
+    model,
+    maxRetries: 0,
+    credentials: {
+      accessKeyId: process.env.BEDROCK_AWS_ACCESS_KEY_ID!,
+      secretAccessKey: process.env.BEDROCK_AWS_SECRET_ACCESS_KEY!,
+    },
+  });
+
+  const stream = await bedrock.stream(prompt);
+  const chunks = [];
+  for await (const chunk of stream) {
+    console.log(chunk);
+    chunks.push(chunk);
+  }
+  expect(chunks.length).toEqual(1);
+});
+
+test("Test Bedrock LLM streaming: Claude-v2", async () => {
+  const region = process.env.BEDROCK_AWS_REGION!;
+  const model = "anthropic.claude-v2";
+  const prompt = "Human: What is your name?\n\nAssistant:";
+
+  const bedrock = new Bedrock({
+    maxTokens: 20,
+    region,
+    model,
+    maxRetries: 0,
+    credentials: {
+      accessKeyId: process.env.BEDROCK_AWS_ACCESS_KEY_ID!,
+      secretAccessKey: process.env.BEDROCK_AWS_SECRET_ACCESS_KEY!,
+    },
+  });
+
+  const stream = await bedrock.stream(prompt);
+  const chunks = [];
+  for await (const chunk of stream) {
+    console.log(chunk);
+    chunks.push(chunk);
+  }
+  expect(chunks.length).toBeGreaterThan(1);
+});
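The new tests above only exercise the streaming path (`bedrock.stream`). A hypothetical companion test for the non-streaming `invoke` path this commit introduces — mirroring the style of the existing tests, but not part of this commit — might look like:

test("Test Bedrock LLM non-streaming invoke: Claude-v2", async () => {
  const region = process.env.BEDROCK_AWS_REGION!;
  const model = "anthropic.claude-v2";
  const prompt = "Human: What is your name?\n\nAssistant:";

  const bedrock = new Bedrock({
    maxTokens: 20,
    region,
    model,
    maxRetries: 0,
    credentials: {
      accessKeyId: process.env.BEDROCK_AWS_ACCESS_KEY_ID!,
      secretAccessKey: process.env.BEDROCK_AWS_SECRET_ACCESS_KEY!,
    },
  });

  // _call now performs a single signed POST to .../invoke
  const res = await bedrock.call(prompt);
  expect(typeof res).toBe("string");
  console.log(res);
});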
