forked from langchain-ai/langchainjs
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
AI21 Labs Integration (langchain-ai#1505)
* AI21 Labs Integration * Yarn formatting for unit test * Fix entrypoint, polish, use AsyncCaller, add tests --------- Co-authored-by: jacoblee93 <[email protected]>
- Loading branch information
1 parent
e5356ac
commit c753c19
Showing
15 changed files
with
277 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
import { AI21 } from "langchain/llms/ai21"; | ||
|
||
const model = new AI21({ | ||
ai21ApiKey: "YOUR_AI21_API_KEY", // Or set as process.env.AI21_API_KEY | ||
}); | ||
|
||
const res = await model.call(`Translate "I love programming" into German.`); | ||
|
||
console.log({ res }); | ||
|
||
/* | ||
{ | ||
res: "\nIch liebe das Programmieren." | ||
} | ||
*/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,183 @@ | ||
import { LLM, BaseLLMParams } from "./base.js"; | ||
import { getEnvironmentVariable } from "../util/env.js"; | ||
|
||
export type AI21PenaltyData = { | ||
scale: number; | ||
applyToWhitespaces: boolean; | ||
applyToPunctuations: boolean; | ||
applyToNumbers: boolean; | ||
applyToStopwords: boolean; | ||
applyToEmojis: boolean; | ||
}; | ||
|
||
export interface AI21Input extends BaseLLMParams { | ||
ai21ApiKey?: string; | ||
model?: string; | ||
temperature?: number; | ||
minTokens?: number; | ||
maxTokens?: number; | ||
topP?: number; | ||
presencePenalty?: AI21PenaltyData; | ||
countPenalty?: AI21PenaltyData; | ||
frequencyPenalty?: AI21PenaltyData; | ||
numResults?: number; | ||
logitBias?: Record<string, number>; | ||
stop?: string[]; | ||
baseUrl?: string; | ||
} | ||
|
||
export class AI21 extends LLM implements AI21Input { | ||
model = "j2-jumbo-instruct"; | ||
|
||
temperature = 0.7; | ||
|
||
maxTokens = 1024; | ||
|
||
minTokens = 0; | ||
|
||
topP = 1; | ||
|
||
presencePenalty = AI21.getDefaultAI21PenaltyData(); | ||
|
||
countPenalty = AI21.getDefaultAI21PenaltyData(); | ||
|
||
frequencyPenalty = AI21.getDefaultAI21PenaltyData(); | ||
|
||
numResults = 1; | ||
|
||
logitBias?: Record<string, number>; | ||
|
||
ai21ApiKey?: string; | ||
|
||
stop?: string[]; | ||
|
||
baseUrl?: string; | ||
|
||
constructor(fields?: AI21Input) { | ||
super(fields ?? {}); | ||
|
||
this.model = fields?.model ?? this.model; | ||
this.temperature = fields?.temperature ?? this.temperature; | ||
this.maxTokens = fields?.maxTokens ?? this.maxTokens; | ||
this.minTokens = fields?.minTokens ?? this.minTokens; | ||
this.topP = fields?.topP ?? this.topP; | ||
this.presencePenalty = fields?.presencePenalty ?? this.presencePenalty; | ||
this.countPenalty = fields?.countPenalty ?? this.countPenalty; | ||
this.frequencyPenalty = fields?.frequencyPenalty ?? this.frequencyPenalty; | ||
this.numResults = fields?.numResults ?? this.numResults; | ||
this.logitBias = fields?.logitBias; | ||
this.ai21ApiKey = | ||
fields?.ai21ApiKey ?? getEnvironmentVariable("AI21_API_KEY"); | ||
this.stop = fields?.stop; | ||
this.baseUrl = fields?.baseUrl; | ||
} | ||
|
||
validateEnvironment() { | ||
if (!this.ai21ApiKey) { | ||
throw new Error( | ||
`No AI21 API key found. Please set it as "AI21_API_KEY" in your environment variables.` | ||
); | ||
} | ||
} | ||
|
||
static getDefaultAI21PenaltyData(): AI21PenaltyData { | ||
return { | ||
scale: 0, | ||
applyToWhitespaces: true, | ||
applyToPunctuations: true, | ||
applyToNumbers: true, | ||
applyToStopwords: true, | ||
applyToEmojis: true, | ||
}; | ||
} | ||
|
||
/** Get the type of LLM. */ | ||
_llmType() { | ||
return "ai21"; | ||
} | ||
|
||
/** Get the default parameters for calling AI21 API. */ | ||
get defaultParams() { | ||
return { | ||
temperature: this.temperature, | ||
maxTokens: this.maxTokens, | ||
minTokens: this.minTokens, | ||
topP: this.topP, | ||
presencePenalty: this.presencePenalty, | ||
countPenalty: this.countPenalty, | ||
frequencyPenalty: this.frequencyPenalty, | ||
numResults: this.numResults, | ||
logitBias: this.logitBias, | ||
}; | ||
} | ||
|
||
/** Get the identifying parameters for this LLM. */ | ||
get identifyingParams() { | ||
return { ...this.defaultParams, model: this.model }; | ||
} | ||
|
||
/** Call out to AI21's complete endpoint. | ||
Args: | ||
prompt: The prompt to pass into the model. | ||
stop: Optional list of stop words to use when generating. | ||
Returns: | ||
The string generated by the model. | ||
Example: | ||
let response = ai21._call("Tell me a joke."); | ||
*/ | ||
async _call( | ||
prompt: string, | ||
options: this["ParsedCallOptions"] | ||
): Promise<string> { | ||
let stop = options?.stop; | ||
this.validateEnvironment(); | ||
if (this.stop && stop && this.stop.length > 0 && stop.length > 0) { | ||
throw new Error("`stop` found in both the input and default params."); | ||
} | ||
stop = this.stop ?? stop ?? []; | ||
|
||
const baseUrl = | ||
this.baseUrl ?? this.model === "j1-grande-instruct" | ||
? "https://api.ai21.com/studio/v1/experimental" | ||
: "https://api.ai21.com/studio/v1"; | ||
|
||
const url = `${baseUrl}/${this.model}/complete`; | ||
const headers = { | ||
Authorization: `Bearer ${this.ai21ApiKey}`, | ||
"Content-Type": "application/json", | ||
}; | ||
const data = { prompt, stopSequences: stop, ...this.defaultParams }; | ||
const responseData = await this.caller.callWithOptions( | ||
{ signal: options.signal }, | ||
async () => { | ||
const response = await fetch(url, { | ||
method: "POST", | ||
headers, | ||
body: JSON.stringify(data), | ||
signal: options.signal, | ||
}); | ||
if (!response.ok) { | ||
const error = new Error( | ||
`AI21 call failed with status code ${response.status}` | ||
); | ||
// eslint-disable-next-line @typescript-eslint/no-explicit-any | ||
(error as any).response = response; | ||
throw error; | ||
} | ||
return response.json(); | ||
} | ||
); | ||
|
||
if ( | ||
!responseData.completions || | ||
responseData.completions.length === 0 || | ||
!responseData.completions[0].data | ||
) { | ||
throw new Error("No completions found in response"); | ||
} | ||
|
||
return responseData.completions[0].data.text ?? ""; | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
import { test, describe, expect } from "@jest/globals"; | ||
import { AI21 } from "../ai21.js"; | ||
|
||
describe("AI21", () => { | ||
test("test call", async () => { | ||
const ai21 = new AI21({}); | ||
const result = await ai21.call( | ||
"What is a good name for a company that makes colorful socks?" | ||
); | ||
console.log({ result }); | ||
}); | ||
|
||
test("test translation call", async () => { | ||
const ai21 = new AI21({}); | ||
const result = await ai21.call( | ||
`Translate "I love programming" into German.` | ||
); | ||
console.log({ result }); | ||
}); | ||
|
||
test("test JSON output call", async () => { | ||
const ai21 = new AI21({}); | ||
const result = await ai21.call( | ||
`Output a JSON object with three string fields: "name", "birthplace", "bio".` | ||
); | ||
console.log({ result }); | ||
}); | ||
|
||
test("should abort the request", async () => { | ||
const ai21 = new AI21({}); | ||
const controller = new AbortController(); | ||
|
||
await expect(() => { | ||
const ret = ai21.call("Respond with an extremely verbose response", { | ||
signal: controller.signal, | ||
}); | ||
controller.abort(); | ||
return ret; | ||
}).rejects.toThrow("AbortError: This operation was aborted"); | ||
}); | ||
|
||
test("throws an error when response status is not ok", async () => { | ||
const ai21 = new AI21({ | ||
ai21ApiKey: "BAD_KEY", | ||
}); | ||
|
||
await expect(ai21.call("Test prompt")).rejects.toThrow( | ||
"AI21 call failed with status code 401" | ||
); | ||
}); | ||
}); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters