Skip to content

Commit

Permalink
BaseLLM -> Base language (langchain-ai#327)
Browse files — browse the repository at this point in the history
* feat: expose langchain/base_language

* another chain

* fake it til its upstream

* crappy test

* agents

* more

* test for ChatOpenAI

* fix base_language deserialize

* format
  • Loading branch information
jacobrosenthal authored Mar 20, 2023
1 parent 1f32d60 commit 3612ae7
Show file tree
Hide file tree
Showing 23 changed files with 132 additions and 69 deletions.
2 changes: 1 addition & 1 deletion docs/docs/modules/chains/chat_vector_db_qa.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ In this code snippet, the fromLLM method of the ChatVectorDBQAChain class has th

```typescript
static fromLLM(
llm: BaseLLM,
llm: BaseLanguageModel,
vectorstore: VectorStore,
options?: {
questionGeneratorTemplate?: string;
Expand Down
2 changes: 2 additions & 0 deletions langchain/.gitignore
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
agents.js
agents.d.ts
base_language.js
base_language.d.ts
tools.js
tools.d.ts
chains.js
Expand Down
1 change: 1 addition & 0 deletions langchain/create-entrypoints.js
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import fs from "fs";

const entrypoints = {
agents: "agents/index",
base_language: "base_language/index",
tools: "agents/tools/index",
chains: "chains/index",
embeddings: "embeddings/index",
Expand Down
6 changes: 6 additions & 0 deletions langchain/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
"dist/",
"agents.js",
"agents.d.ts",
"base_language.js",
"base_language.d.ts",
"tools.js",
"tools.d.ts",
"chains.js",
Expand Down Expand Up @@ -220,6 +222,10 @@
"types": "./agents.d.ts",
"import": "./agents.js"
},
"./base_language": {
"types": "./base_language.d.ts",
"import": "./base_language.js"
},
"./tools": {
"types": "./tools.d.ts",
"import": "./tools.js"
Expand Down
6 changes: 3 additions & 3 deletions langchain/src/agents/agent.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { BaseLLM } from "../llms/index.js";
import { BaseLanguageModel } from "../base_language/index.js";
import { LLMChain } from "../chains/llm_chain.js";
import { BasePromptTemplate } from "../prompts/index.js";
import {
Expand Down Expand Up @@ -98,7 +98,7 @@ export abstract class Agent {

/** Construct an agent from an LLM and a list of tools */
static fromLLMAndTools(
_llm: BaseLLM,
_llm: BaseLanguageModel,
_tools: Tool[],
// eslint-disable-next-line @typescript-eslint/no-explicit-any
_args?: Record<string, any>
Expand Down Expand Up @@ -226,7 +226,7 @@ export abstract class Agent {
* Load an agent from a json-like object describing it.
*/
static async deserialize(
data: SerializedAgent & { llm?: BaseLLM; tools?: Tool[] }
data: SerializedAgent & { llm?: BaseLanguageModel; tools?: Tool[] }
): Promise<Agent> {
switch (data._type) {
case "zero-shot-react-description": {
Expand Down
4 changes: 2 additions & 2 deletions langchain/src/agents/agent_toolkits/json/json.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { BaseLLM } from "../../../llms/index.js";
import { BaseLanguageModel } from "../../../base_language/index.js";
import {
JsonGetValueTool,
JsonListKeysTool,
Expand All @@ -24,7 +24,7 @@ export class JsonToolkit extends Toolkit {
}

export function createJsonAgent(
llm: BaseLLM,
llm: BaseLanguageModel,
toolkit: JsonToolkit,
args?: CreatePromptArgs
) {
Expand Down
6 changes: 3 additions & 3 deletions langchain/src/agents/agent_toolkits/openapi/openapi.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { BaseLLM } from "../../../llms/index.js";
import { BaseLanguageModel } from "../../../base_language/index.js";
import {
DynamicTool,
JsonSpec,
Expand Down Expand Up @@ -28,7 +28,7 @@ export class RequestsToolkit extends Toolkit {
}

export class OpenApiToolkit extends RequestsToolkit {
constructor(jsonSpec: JsonSpec, llm: BaseLLM, headers?: Headers) {
constructor(jsonSpec: JsonSpec, llm: BaseLanguageModel, headers?: Headers) {
super(headers);
const jsonAgent = createJsonAgent(llm, new JsonToolkit(jsonSpec));
this.tools = [
Expand All @@ -46,7 +46,7 @@ export class OpenApiToolkit extends RequestsToolkit {
}

export function createOpenApiAgent(
llm: BaseLLM,
llm: BaseLanguageModel,
openApiToolkit: OpenApiToolkit,
args?: CreatePromptArgs
) {
Expand Down
4 changes: 2 additions & 2 deletions langchain/src/agents/agent_toolkits/sql/sql.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import {
QuerySqlTool,
} from "../../tools/index.js";
import { Toolkit } from "../base.js";
import { BaseLLM } from "../../../llms/index.js";
import { BaseLanguageModel } from "../../../base_language/index.js";
import { SQL_PREFIX, SQL_SUFFIX } from "./prompt.js";
import { renderTemplate } from "../../../prompts/template.js";
import { LLMChain } from "../../../chains/index.js";
Expand Down Expand Up @@ -39,7 +39,7 @@ export class SqlToolkit extends Toolkit {
}

export function createSqlAgent(
llm: BaseLLM,
llm: BaseLanguageModel,
toolkit: SqlToolkit,
args?: SqlCreatePromptArgs
) {
Expand Down
14 changes: 7 additions & 7 deletions langchain/src/agents/agent_toolkits/vectorstore/vectorstore.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { Tool, VectorStoreQATool } from "../../tools/index.js";
import { VectorStore } from "../../../vectorstores/index.js";
import { Toolkit } from "../base.js";
import { BaseLLM } from "../../../llms/index.js";
import { BaseLanguageModel } from "../../../base_language/index.js";
import { CreatePromptArgs, ZeroShotAgent } from "../../mrkl/index.js";
import { VECTOR_PREFIX, VECTOR_ROUTER_PREFIX } from "./prompt.js";
import { SUFFIX } from "../../mrkl/prompt.js";
Expand All @@ -17,9 +17,9 @@ export interface VectorStoreInfo {
export class VectorStoreToolkit extends Toolkit {
tools: Tool[];

llm: BaseLLM;
llm: BaseLanguageModel;

constructor(vectorStoreInfo: VectorStoreInfo, llm: BaseLLM) {
constructor(vectorStoreInfo: VectorStoreInfo, llm: BaseLanguageModel) {
super();
const description = VectorStoreQATool.getDescription(
vectorStoreInfo.name,
Expand All @@ -40,9 +40,9 @@ export class VectorStoreRouterToolkit extends Toolkit {

vectorStoreInfos: VectorStoreInfo[];

llm: BaseLLM;
llm: BaseLanguageModel;

constructor(vectorStoreInfos: VectorStoreInfo[], llm: BaseLLM) {
constructor(vectorStoreInfos: VectorStoreInfo[], llm: BaseLanguageModel) {
super();
this.llm = llm;
this.vectorStoreInfos = vectorStoreInfos;
Expand All @@ -60,7 +60,7 @@ export class VectorStoreRouterToolkit extends Toolkit {
}

export function createVectorStoreAgent(
llm: BaseLLM,
llm: BaseLanguageModel,
toolkit: VectorStoreToolkit,
args?: CreatePromptArgs
) {
Expand Down Expand Up @@ -88,7 +88,7 @@ export function createVectorStoreAgent(
}

export function createVectorStoreRouterAgent(
llm: BaseLLM,
llm: BaseLanguageModel,
toolkit: VectorStoreRouterToolkit,
args?: CreatePromptArgs
) {
Expand Down
6 changes: 3 additions & 3 deletions langchain/src/agents/load.ts
Original file line number Diff line number Diff line change
@@ -1,21 +1,21 @@
import { Agent } from "./agent.js";
import { Tool } from "./tools/base.js";
import { BaseLLM } from "../llms/index.js";
import { BaseLanguageModel } from "../base_language/index.js";
import { loadFromHub } from "../util/hub.js";
import { FileLoader, loadFromFile, parseFileConfig } from "../util/index.js";

const loadAgentFromFile: FileLoader<Agent> = async (
file: string,
path: string,
llmAndTools?: { llm?: BaseLLM; tools?: Tool[] }
llmAndTools?: { llm?: BaseLanguageModel; tools?: Tool[] }
) => {
const serialized = parseFileConfig(file, path);
return Agent.deserialize({ ...serialized, ...llmAndTools });
};

export const loadAgent = async (
uri: string,
llmAndTools?: { llm?: BaseLLM; tools?: Tool[] }
llmAndTools?: { llm?: BaseLanguageModel; tools?: Tool[] }
): Promise<Agent> => {
const hubResult = await loadFromHub(
uri,
Expand Down
6 changes: 3 additions & 3 deletions langchain/src/agents/tools/vectorstore.ts
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
import { VectorStore } from "../../vectorstores/index.js";
import { BaseLLM } from "../../llms/index.js";
import { BaseLanguageModel } from "../../base_language/index.js";
import { VectorDBQAChain } from "../../chains/index.js";
import { Tool } from "./base.js";

interface VectorStoreTool {
vectorStore: VectorStore;
llm: BaseLLM;
llm: BaseLanguageModel;
}

export class VectorStoreQATool extends Tool implements VectorStoreTool {
vectorStore: VectorStore;

llm: BaseLLM;
llm: BaseLanguageModel;

name: string;

Expand Down
44 changes: 44 additions & 0 deletions langchain/src/base_language/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,12 @@ import { CallbackManager, getCallbackManager } from "../callbacks/index.js";

const getVerbosity = () => false;

/**
 * JSON-serializable representation of a language model.
 * `_model` names the model family (the code below accepts "base_chat_model")
 * and `_type` selects the concrete implementation on deserialization;
 * any remaining keys are the model's identifying parameters.
 */
export type SerializedLLM = {
  _model: string;
  _type: string;
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
} & Record<string, any>;

/**
* Base interface for language model parameters.
* A subclass of {@link BaseLanguageModel} should have a constructor that
Expand Down Expand Up @@ -36,5 +42,43 @@ export abstract class BaseLanguageModel implements BaseLanguageModelParams {

abstract _modelType(): string;

abstract _llmType(): string;

abstract getNumTokens(text: string): number;

/**
 * Get the identifying parameters of the LLM.
 *
 * Subclasses override this to expose the constructor parameters needed to
 * reconstruct the model; the result is merged into `serialize()` output.
 * The base implementation has no parameters, so it returns an empty object.
 */
// eslint-disable-next-line @typescript-eslint/no-explicit-any
_identifyingParams(): Record<string, any> {
  return {};
}

/**
 * Return a json-like object representing this LLM.
 *
 * Combines the model's identifying parameters with the `_type` and
 * `_model` discriminator fields, which take precedence over any
 * identically-named identifying parameters.
 */
serialize(): SerializedLLM {
  const discriminators = {
    _type: this._llmType(),
    _model: this._modelType(),
  };
  return { ...this._identifyingParams(), ...discriminators };
}

/**
 * Load an LLM from a json-like object describing it.
 *
 * Only chat models are restorable here: `_model` must be absent or equal to
 * "base_chat_model", and `_type` selects the implementation (currently only
 * "openai" -> ChatOpenAI). All remaining fields are forwarded to the
 * constructor.
 *
 * @throws Error if `_model` names an unsupported family or `_type` is
 *   unrecognized.
 */
static async deserialize(data: SerializedLLM): Promise<BaseLanguageModel> {
  const { _type, _model, ...rest } = data;
  // Guard: this base-class loader only knows how to restore chat models.
  if (_model && _model !== "base_chat_model") {
    throw new Error(`Cannot load LLM with model ${_model}`);
  }
  // Dispatch table from the serialized _type tag to a concrete class,
  // loaded lazily via dynamic import.
  const Cls = {
    openai: (await import("../chat_models/openai.js")).ChatOpenAI,
  }[_type];
  if (Cls === undefined) {
    throw new Error(`Cannot load LLM with type ${_type}`);
  }
  return new Cls(rest);
}
}
4 changes: 2 additions & 2 deletions langchain/src/chains/chat_vector_db_chain.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { PromptTemplate } from "../prompts/index.js";
import { BaseLLM } from "../llms/index.js";
import { BaseLanguageModel } from "../base_language/index.js";
import { VectorStore } from "../vectorstores/base.js";
import {
SerializedBaseChain,
Expand Down Expand Up @@ -167,7 +167,7 @@ export class ChatVectorDBQAChain
}

static fromLLM(
llm: BaseLLM,
llm: BaseLanguageModel,
vectorstore: VectorStore,
options: {
inputKey?: string;
Expand Down
7 changes: 4 additions & 3 deletions langchain/src/chains/llm_chain.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import { BaseChain, ChainInputs } from "./base.js";
import { BaseLLM, SerializedLLM } from "../llms/index.js";
import { SerializedLLM } from "../llms/index.js";

import { BaseMemory, BufferMemory } from "../memory/index.js";
import {
BasePromptTemplate,
Expand Down Expand Up @@ -97,15 +98,15 @@ export class LLMChain extends BaseChain implements LLMChainInput {
>("prompt", data);

return new LLMChain({
llm: await BaseLLM.deserialize(serializedLLM),
llm: await BaseLanguageModel.deserialize(serializedLLM),
prompt: await BasePromptTemplate.deserialize(serializedPrompt),
});
}

serialize(): SerializedLLMChain {
return {
_type: this._chainType(),
// llm: this.llm.serialize(), TODO fix this now that llm is BaseLanguageModel
llm: this.llm.serialize(),
prompt: this.prompt.serialize(),
};
}
Expand Down
5 changes: 2 additions & 3 deletions langchain/src/chains/prompt_selector.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import { BaseChatModel } from "../chat_models/base.js";
import { BaseLLM } from "../llms/base.js";
import { BasePromptTemplate } from "../prompts/base.js";
import { BaseLanguageModel } from "../base_language/index.js";

Expand Down Expand Up @@ -38,8 +37,8 @@ export class ConditionalPromptSelector extends BasePromptSelector {
}
}

export function isLLM(llm: BaseLanguageModel): llm is BaseLLM {
return llm instanceof BaseLLM;
export function isLLM(llm: BaseLanguageModel): llm is BaseLanguageModel {
return llm instanceof BaseLanguageModel;
}

export function isChatModel(llm: BaseLanguageModel): llm is BaseChatModel {
Expand Down
11 changes: 7 additions & 4 deletions langchain/src/chains/question_answering/load.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import { BaseLLM } from "../../llms/index.js";
import { LLMChain } from "../llm_chain.js";
import { PromptTemplate } from "../../prompts/index.js";
import {
Expand All @@ -12,14 +11,18 @@ import {
COMBINE_PROMPT_SELECTOR,
COMBINE_QA_PROMPT_SELECTOR,
} from "./map_reduce_prompts.js";
import { BaseLanguageModel } from "../../base_language/index.js";

interface qaChainParams {
prompt?: PromptTemplate;
combineMapPrompt?: PromptTemplate;
combinePrompt?: PromptTemplate;
type?: string;
}
export const loadQAChain = (llm: BaseLLM, params: qaChainParams = {}) => {
export const loadQAChain = (
llm: BaseLanguageModel,
params: qaChainParams = {}
) => {
const {
prompt = DEFAULT_QA_PROMPT,
combineMapPrompt = DEFAULT_COMBINE_QA_PROMPT,
Expand Down Expand Up @@ -52,7 +55,7 @@ interface StuffQAChainParams {
}

export const loadQAStuffChain = (
llm: BaseLLM,
llm: BaseLanguageModel,
params: StuffQAChainParams = {}
) => {
const { prompt = QA_PROMPT_SELECTOR.getPrompt(llm) } = params;
Expand All @@ -67,7 +70,7 @@ interface MapReduceQAChainParams {
}

export const loadQAMapReduceChain = (
llm: BaseLLM,
llm: BaseLanguageModel,
params: MapReduceQAChainParams = {}
) => {
const {
Expand Down
Loading

0 comments on commit 3612ae7

Please sign in to comment.