Skip to content

Commit

Permalink
Adds base URL support for Cloudflare WorkersAI (langchain-ai#2970)
Browse files Browse the repository at this point in the history
* Adds base URL support

* Cleanup
  • Loading branch information
jacoblee93 authored Oct 19, 2023
1 parent 23a2b24 commit 5ce6a78
Show file tree
Hide file tree
Showing 7 changed files with 66 additions and 16 deletions.
2 changes: 2 additions & 0 deletions examples/src/models/chat/integration_cloudflare_workersai.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ const model = new ChatCloudflareWorkersAI({
model: "@cf/meta/llama-2-7b-chat-int8", // Default value
cloudflareAccountId: process.env.CLOUDFLARE_ACCOUNT_ID,
cloudflareApiToken: process.env.CLOUDFLARE_API_TOKEN,
// Pass a custom base URL to use Cloudflare AI Gateway
// baseUrl: `https://gateway.ai.cloudflare.com/v1/{YOUR_ACCOUNT_ID}/{GATEWAY_NAME}/workers-ai/`,
});

const response = await model.invoke([
Expand Down
2 changes: 2 additions & 0 deletions examples/src/models/llm/cloudflare_workersai.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ const model = new CloudflareWorkersAI({
model: "@cf/meta/llama-2-7b-chat-int8", // Default value
cloudflareAccountId: process.env.CLOUDFLARE_ACCOUNT_ID,
cloudflareApiToken: process.env.CLOUDFLARE_API_TOKEN,
// Pass a custom base URL to use Cloudflare AI Gateway
// baseUrl: `https://gateway.ai.cloudflare.com/v1/{YOUR_ACCOUNT_ID}/{GATEWAY_NAME}/workers-ai/`,
});

const response = await model.invoke(
Expand Down
11 changes: 8 additions & 3 deletions langchain/src/chat_models/cloudflare_workersai.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ export class ChatCloudflareWorkersAI

cloudflareApiToken?: string;

baseUrl?: string;
baseUrl: string;

constructor(fields?: CloudflareWorkersAIInput & BaseChatModelParams) {
super(fields ?? {});
Expand All @@ -44,7 +44,12 @@ export class ChatCloudflareWorkersAI
this.cloudflareApiToken =
fields?.cloudflareApiToken ??
getEnvironmentVariable("CLOUDFLARE_API_TOKEN");
this.baseUrl = fields?.baseUrl;
this.baseUrl =
fields?.baseUrl ??
`https://api.cloudflare.com/client/v4/accounts/${this.cloudflareAccountId}/ai/run`;
if (this.baseUrl.endsWith("/")) {
this.baseUrl = this.baseUrl.slice(0, -1);
}
}

_llmType() {
Expand Down Expand Up @@ -119,7 +124,7 @@ export class ChatCloudflareWorkersAI
): Promise<string> {
this.validateEnvironment();

const url = `https://api.cloudflare.com/client/v4/accounts/${this.cloudflareAccountId}/ai/run/${this.model}`;
const url = `${this.baseUrl}/${this.model}`;
const headers = {
Authorization: `Bearer ${this.cloudflareApiToken}`,
"Content-Type": "application/json",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import {
SystemMessagePromptTemplate,
} from "../../prompts/index.js";
import { ChatCloudflareWorkersAI } from "../cloudflare_workersai.js";
import { getEnvironmentVariable } from "../../util/env.js";

describe("ChatCloudflareWorkersAI", () => {
test("call", async () => {
Expand Down Expand Up @@ -70,4 +71,26 @@ describe("ChatCloudflareWorkersAI", () => {

console.log(responseA.generations);
});

test.skip("custom base url", async () => {
const chat = new ChatCloudflareWorkersAI({
baseUrl: `https://gateway.ai.cloudflare.com/v1/${getEnvironmentVariable(
"CLOUDFLARE_ACCOUNT_ID"
)}/lang-chainjs/workers-ai/`,
});

const chatPrompt = ChatPromptTemplate.fromMessages([
HumanMessagePromptTemplate.fromTemplate(`Hi, my name is Joe!`),
AIMessagePromptTemplate.fromTemplate(`Nice to meet you, Joe!`),
HumanMessagePromptTemplate.fromTemplate("{text}"),
]);

const responseA = await chat.generatePrompt([
await chatPrompt.formatPromptValue({
text: "What did I just say my name was?",
}),
]);

console.log(responseA.generations);
});
});
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ const dummyMessages = [
),
];

test("should respond with the proper schema", async () => {
test.skip("should respond with the proper schema", async () => {
const vectorStore = await HNSWLib.fromTexts(
[" "],
[{ id: 1 }],
Expand Down
31 changes: 19 additions & 12 deletions langchain/src/llms/cloudflare_workersai.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ export class CloudflareWorkersAI

cloudflareApiToken?: string;

baseUrl?: string;
baseUrl: string;

static lc_name() {
return "CloudflareWorkersAI";
Expand All @@ -44,22 +44,29 @@ export class CloudflareWorkersAI
this.cloudflareApiToken =
fields?.cloudflareApiToken ??
getEnvironmentVariable("CLOUDFLARE_API_TOKEN");
this.baseUrl = fields?.baseUrl;
this.baseUrl =
fields?.baseUrl ??
`https://api.cloudflare.com/client/v4/accounts/${this.cloudflareAccountId}/ai/run`;
if (this.baseUrl.endsWith("/")) {
this.baseUrl = this.baseUrl.slice(0, -1);
}
}

/**
* Method to validate the environment.
*/
validateEnvironment() {
if (!this.cloudflareAccountId) {
throw new Error(
`No Cloudflare account ID found. Please provide it when instantiating the CloudflareWorkersAI class, or set it as "CLOUDFLARE_ACCOUNT_ID" in your environment variables.`
);
}
if (!this.cloudflareApiToken) {
throw new Error(
`No Cloudflare API key found. Please provide it when instantiating the CloudflareWorkersAI class, or set it as "CLOUDFLARE_API_KEY" in your environment variables.`
);
if (this.baseUrl === undefined) {
if (!this.cloudflareAccountId) {
throw new Error(
`No Cloudflare account ID found. Please provide it when instantiating the CloudflareWorkersAI class, or set it as "CLOUDFLARE_ACCOUNT_ID" in your environment variables.`
);
}
if (!this.cloudflareApiToken) {
throw new Error(
`No Cloudflare API key found. Please provide it when instantiating the CloudflareWorkersAI class, or set it as "CLOUDFLARE_API_KEY" in your environment variables.`
);
}
}
}

Expand Down Expand Up @@ -96,7 +103,7 @@ export class CloudflareWorkersAI
): Promise<string> {
this.validateEnvironment();

const url = `https://api.cloudflare.com/client/v4/accounts/${this.cloudflareAccountId}/ai/run/${this.model}`;
const url = `${this.baseUrl}/${this.model}`;
const headers = {
Authorization: `Bearer ${this.cloudflareApiToken}`,
"Content-Type": "application/json",
Expand Down
11 changes: 11 additions & 0 deletions langchain/src/llms/tests/cloudflare_workersai.int.test.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,19 @@
import { test } from "@jest/globals";
import { CloudflareWorkersAI } from "../cloudflare_workersai.js";
import { getEnvironmentVariable } from "../../util/env.js";

test("Test CloudflareWorkersAI", async () => {
const model = new CloudflareWorkersAI({});
const res = await model.call("1 + 1 =");
console.log(res);
}, 50000);

test.skip("Test custom base url", async () => {
const model = new CloudflareWorkersAI({
baseUrl: `https://gateway.ai.cloudflare.com/v1/${getEnvironmentVariable(
"CLOUDFLARE_ACCOUNT_ID"
)}/lang-chainjs/workers-ai/`,
});
const res = await model.call("1 + 1 =");
console.log(res);
});

0 comments on commit 5ce6a78

Please sign in to comment.