AI21 Labs Integration (langchain-ai#1505)
* AI21 Labs Integration

* Yarn formatting for unit test

* Fix entrypoint, polish, use AsyncCaller, add tests

---------

Co-authored-by: jacoblee93 <[email protected]>
paaatrrrick and jacoblee93 authored Jun 3, 2023
1 parent e5356ac commit c753c19
Showing 15 changed files with 277 additions and 0 deletions.
8 changes: 8 additions & 0 deletions docs/docs/modules/models/llms/integrations.mdx
@@ -135,6 +135,14 @@ import SageMakerEndpointExample from "@examples/models/llm/sagemaker_endpoint.ts

<CodeBlock language="typescript">{SageMakerEndpointExample}</CodeBlock>

## `AI21`

You can get started with AI21 Labs' Jurassic family of models, and see the full list of available foundation models, by signing up for an API key [on their website](https://www.ai21.com/).

import AI21Example from "@examples/models/llm/ai21.ts";

<CodeBlock language="typescript">{AI21Example}</CodeBlock>

## Additional LLM Implementations

### `PromptLayerOpenAI`
15 changes: 15 additions & 0 deletions examples/src/models/llm/ai21.ts
@@ -0,0 +1,15 @@
import { AI21 } from "langchain/llms/ai21";

const model = new AI21({
ai21ApiKey: "YOUR_AI21_API_KEY", // Or set as process.env.AI21_API_KEY
});

const res = await model.call(`Translate "I love programming" into German.`);

console.log({ res });

/*
{
res: "\nIch liebe das Programmieren."
}
*/
3 changes: 3 additions & 0 deletions langchain/.gitignore
@@ -67,6 +67,9 @@ llms/base.d.ts
llms/openai.cjs
llms/openai.js
llms/openai.d.ts
llms/ai21.cjs
llms/ai21.js
llms/ai21.d.ts
llms/cohere.cjs
llms/cohere.js
llms/cohere.d.ts
8 changes: 8 additions & 0 deletions langchain/package.json
@@ -79,6 +79,9 @@
"llms/openai.cjs",
"llms/openai.js",
"llms/openai.d.ts",
"llms/ai21.cjs",
"llms/ai21.js",
"llms/ai21.d.ts",
"llms/cohere.cjs",
"llms/cohere.js",
"llms/cohere.d.ts",
@@ -811,6 +814,11 @@
"import": "./llms/openai.js",
"require": "./llms/openai.cjs"
},
"./llms/ai21": {
"types": "./llms/ai21.d.ts",
"import": "./llms/ai21.js",
"require": "./llms/ai21.cjs"
},
"./llms/cohere": {
"types": "./llms/cohere.d.ts",
"import": "./llms/cohere.js",
1 change: 1 addition & 0 deletions langchain/scripts/create-entrypoints.js
@@ -36,6 +36,7 @@ const entrypoints = {
"llms/load": "llms/load",
"llms/base": "llms/base",
"llms/openai": "llms/openai",
"llms/ai21": "llms/ai21",
"llms/cohere": "llms/cohere",
"llms/hf": "llms/hf",
"llms/replicate": "llms/replicate",
183 changes: 183 additions & 0 deletions langchain/src/llms/ai21.ts
@@ -0,0 +1,183 @@
import { LLM, BaseLLMParams } from "./base.js";
import { getEnvironmentVariable } from "../util/env.js";

export type AI21PenaltyData = {
scale: number;
applyToWhitespaces: boolean;
applyToPunctuations: boolean;
applyToNumbers: boolean;
applyToStopwords: boolean;
applyToEmojis: boolean;
};

export interface AI21Input extends BaseLLMParams {
ai21ApiKey?: string;
model?: string;
temperature?: number;
minTokens?: number;
maxTokens?: number;
topP?: number;
presencePenalty?: AI21PenaltyData;
countPenalty?: AI21PenaltyData;
frequencyPenalty?: AI21PenaltyData;
numResults?: number;
logitBias?: Record<string, number>;
stop?: string[];
baseUrl?: string;
}

export class AI21 extends LLM implements AI21Input {
model = "j2-jumbo-instruct";

temperature = 0.7;

maxTokens = 1024;

minTokens = 0;

topP = 1;

presencePenalty = AI21.getDefaultAI21PenaltyData();

countPenalty = AI21.getDefaultAI21PenaltyData();

frequencyPenalty = AI21.getDefaultAI21PenaltyData();

numResults = 1;

logitBias?: Record<string, number>;

ai21ApiKey?: string;

stop?: string[];

baseUrl?: string;

constructor(fields?: AI21Input) {
super(fields ?? {});

this.model = fields?.model ?? this.model;
this.temperature = fields?.temperature ?? this.temperature;
this.maxTokens = fields?.maxTokens ?? this.maxTokens;
this.minTokens = fields?.minTokens ?? this.minTokens;
this.topP = fields?.topP ?? this.topP;
this.presencePenalty = fields?.presencePenalty ?? this.presencePenalty;
this.countPenalty = fields?.countPenalty ?? this.countPenalty;
this.frequencyPenalty = fields?.frequencyPenalty ?? this.frequencyPenalty;
this.numResults = fields?.numResults ?? this.numResults;
this.logitBias = fields?.logitBias;
this.ai21ApiKey =
fields?.ai21ApiKey ?? getEnvironmentVariable("AI21_API_KEY");
this.stop = fields?.stop;
this.baseUrl = fields?.baseUrl;
}

validateEnvironment() {
if (!this.ai21ApiKey) {
throw new Error(
`No AI21 API key found. Please set it as "AI21_API_KEY" in your environment variables.`
);
}
}

static getDefaultAI21PenaltyData(): AI21PenaltyData {
return {
scale: 0,
applyToWhitespaces: true,
applyToPunctuations: true,
applyToNumbers: true,
applyToStopwords: true,
applyToEmojis: true,
};
}

/** Get the type of LLM. */
_llmType() {
return "ai21";
}

/** Get the default parameters for calling AI21 API. */
get defaultParams() {
return {
temperature: this.temperature,
maxTokens: this.maxTokens,
minTokens: this.minTokens,
topP: this.topP,
presencePenalty: this.presencePenalty,
countPenalty: this.countPenalty,
frequencyPenalty: this.frequencyPenalty,
numResults: this.numResults,
logitBias: this.logitBias,
};
}

/** Get the identifying parameters for this LLM. */
get identifyingParams() {
return { ...this.defaultParams, model: this.model };
}

/** Call out to AI21's complete endpoint.
Args:
prompt: The prompt to pass into the model.
stop: Optional list of stop words to use when generating.
Returns:
The string generated by the model.
Example:
const response = await ai21.call("Tell me a joke.");
*/
async _call(
prompt: string,
options: this["ParsedCallOptions"]
): Promise<string> {
let stop = options?.stop;
this.validateEnvironment();
if (this.stop && stop && this.stop.length > 0 && stop.length > 0) {
throw new Error("`stop` found in both the input and default params.");
}
stop = this.stop ?? stop ?? [];

const baseUrl =
  this.baseUrl ??
  (this.model === "j1-grande-instruct"
    ? "https://api.ai21.com/studio/v1/experimental"
    : "https://api.ai21.com/studio/v1");

const url = `${baseUrl}/${this.model}/complete`;
const headers = {
Authorization: `Bearer ${this.ai21ApiKey}`,
"Content-Type": "application/json",
};
const data = { prompt, stopSequences: stop, ...this.defaultParams };
const responseData = await this.caller.callWithOptions(
{ signal: options.signal },
async () => {
const response = await fetch(url, {
method: "POST",
headers,
body: JSON.stringify(data),
signal: options.signal,
});
if (!response.ok) {
const error = new Error(
`AI21 call failed with status code ${response.status}`
);
// eslint-disable-next-line @typescript-eslint/no-explicit-any
(error as any).response = response;
throw error;
}
return response.json();
}
);

if (
!responseData.completions ||
responseData.completions.length === 0 ||
!responseData.completions[0].data
) {
throw new Error("No completions found in response");
}

return responseData.completions[0].data.text ?? "";
}
}
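
For reference, a minimal usage sketch of the class added above. It assumes the `langchain/llms/ai21` entrypoint from this commit has been built and that `AI21_API_KEY` is set; the parameter values are illustrative and every option falls back to the class defaults when omitted.

import { AI21 } from "langchain/llms/ai21";

// Illustrative values; each option is optional and falls back to the
// defaults defined in the AI21 class above.
const model = new AI21({
  model: "j2-jumbo-instruct",
  temperature: 0.5,
  maxTokens: 256,
  stop: ["##"],
  presencePenalty: {
    scale: 1,
    applyToWhitespaces: true,
    applyToPunctuations: false,
    applyToNumbers: false,
    applyToStopwords: false,
    applyToEmojis: true,
  },
  // ai21ApiKey is read from process.env.AI21_API_KEY when omitted.
});

const res = await model.call(
  "Suggest three names for a colorful sock company."
);
console.log(res);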
51 changes: 51 additions & 0 deletions langchain/src/llms/tests/ai21.int.test.ts
@@ -0,0 +1,51 @@
import { test, describe, expect } from "@jest/globals";
import { AI21 } from "../ai21.js";

describe("AI21", () => {
test("test call", async () => {
const ai21 = new AI21({});
const result = await ai21.call(
"What is a good name for a company that makes colorful socks?"
);
console.log({ result });
});

test("test translation call", async () => {
const ai21 = new AI21({});
const result = await ai21.call(
`Translate "I love programming" into German.`
);
console.log({ result });
});

test("test JSON output call", async () => {
const ai21 = new AI21({});
const result = await ai21.call(
`Output a JSON object with three string fields: "name", "birthplace", "bio".`
);
console.log({ result });
});

test("should abort the request", async () => {
const ai21 = new AI21({});
const controller = new AbortController();

await expect(() => {
const ret = ai21.call("Respond with an extremely verbose response", {
signal: controller.signal,
});
controller.abort();
return ret;
}).rejects.toThrow("AbortError: This operation was aborted");
});

test("throws an error when response status is not ok", async () => {
const ai21 = new AI21({
ai21ApiKey: "BAD_KEY",
});

await expect(ai21.call("Test prompt")).rejects.toThrow(
"AI21 call failed with status code 401"
);
});
});
1 change: 1 addition & 0 deletions langchain/tsconfig.json
@@ -53,6 +53,7 @@
"src/llms/load.ts",
"src/llms/base.ts",
"src/llms/openai.ts",
"src/llms/ai21.ts",
"src/llms/cohere.ts",
"src/llms/hf.ts",
"src/llms/replicate.ts",
1 change: 1 addition & 0 deletions test-exports-cf/src/entrypoints.js
@@ -7,6 +7,7 @@ export * from "langchain/embeddings/fake";
export * from "langchain/embeddings/openai";
export * from "langchain/llms/base";
export * from "langchain/llms/openai";
export * from "langchain/llms/ai21";
export * from "langchain/prompts";
export * from "langchain/vectorstores/base";
export * from "langchain/vectorstores/memory";
1 change: 1 addition & 0 deletions test-exports-cjs/src/entrypoints.js
@@ -7,6 +7,7 @@ const embeddings_fake = require("langchain/embeddings/fake");
const embeddings_openai = require("langchain/embeddings/openai");
const llms_base = require("langchain/llms/base");
const llms_openai = require("langchain/llms/openai");
const llms_ai21 = require("langchain/llms/ai21");
const prompts = require("langchain/prompts");
const vectorstores_base = require("langchain/vectorstores/base");
const vectorstores_memory = require("langchain/vectorstores/memory");
1 change: 1 addition & 0 deletions test-exports-cra/src/entrypoints.js
@@ -7,6 +7,7 @@ export * from "langchain/embeddings/fake";
export * from "langchain/embeddings/openai";
export * from "langchain/llms/base";
export * from "langchain/llms/openai";
export * from "langchain/llms/ai21";
export * from "langchain/prompts";
export * from "langchain/vectorstores/base";
export * from "langchain/vectorstores/memory";
1 change: 1 addition & 0 deletions test-exports-esbuild/src/entrypoints.js
@@ -7,6 +7,7 @@ import * as embeddings_fake from "langchain/embeddings/fake";
import * as embeddings_openai from "langchain/embeddings/openai";
import * as llms_base from "langchain/llms/base";
import * as llms_openai from "langchain/llms/openai";
import * as llms_ai21 from "langchain/llms/ai21";
import * as prompts from "langchain/prompts";
import * as vectorstores_base from "langchain/vectorstores/base";
import * as vectorstores_memory from "langchain/vectorstores/memory";
1 change: 1 addition & 0 deletions test-exports-esm/src/entrypoints.js
@@ -7,6 +7,7 @@ import * as embeddings_fake from "langchain/embeddings/fake";
import * as embeddings_openai from "langchain/embeddings/openai";
import * as llms_base from "langchain/llms/base";
import * as llms_openai from "langchain/llms/openai";
import * as llms_ai21 from "langchain/llms/ai21";
import * as prompts from "langchain/prompts";
import * as vectorstores_base from "langchain/vectorstores/base";
import * as vectorstores_memory from "langchain/vectorstores/memory";
1 change: 1 addition & 0 deletions test-exports-vercel/src/entrypoints.js
@@ -7,6 +7,7 @@ export * from "langchain/embeddings/fake";
export * from "langchain/embeddings/openai";
export * from "langchain/llms/base";
export * from "langchain/llms/openai";
export * from "langchain/llms/ai21";
export * from "langchain/prompts";
export * from "langchain/vectorstores/base";
export * from "langchain/vectorstores/memory";
1 change: 1 addition & 0 deletions test-exports-vite/src/entrypoints.js
@@ -7,6 +7,7 @@ export * from "langchain/embeddings/fake";
export * from "langchain/embeddings/openai";
export * from "langchain/llms/base";
export * from "langchain/llms/openai";
export * from "langchain/llms/ai21";
export * from "langchain/prompts";
export * from "langchain/vectorstores/base";
export * from "langchain/vectorstores/memory";