Skip to content

Commit

Permalink
Add vectorstore agent (langchain-ai#175)
Browse files Browse the repository at this point in the history
* add vectorstore tool

* vector store agents

* fix exports

* cr

* cr

---------

Co-authored-by: Harrison Chase <[email protected]>
  • Loading branch information
agola11 and hwchase17 authored Mar 1, 2023
1 parent 6148be9 commit e0a3347
Show file tree
Hide file tree
Showing 9 changed files with 220 additions and 44 deletions.
2 changes: 1 addition & 1 deletion examples/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -45,4 +45,4 @@
"tsx": "^3.12.3",
"typescript": "^4.9.5"
}
}
}
43 changes: 0 additions & 43 deletions examples/src/agents/agents_vectorstore.ts

This file was deleted.

44 changes: 44 additions & 0 deletions examples/src/agents/vectorstore.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import { OpenAI } from "langchain";
import { HNSWLib } from "langchain/vectorstores";
import { OpenAIEmbeddings } from "langchain/embeddings";
import { RecursiveCharacterTextSplitter } from "langchain/text_splitter";
import * as fs from "fs";
import {
VectorStoreToolkit,
createVectorStoreAgent,
VectorStoreInfo,
} from "langchain/agents";

export const run = async () => {
const model = new OpenAI({ temperature: 0 });
/* Load in the file we want to do question answering over */
const text = fs.readFileSync("state_of_the_union.txt", "utf8");
/* Split the text into chunks */
const textSplitter = new RecursiveCharacterTextSplitter({ chunkSize: 1000 });
const docs = await textSplitter.createDocuments([text]);
/* Create the vectorstore */
const vectorStore = await HNSWLib.fromDocuments(docs, new OpenAIEmbeddings());

/* Create the agent */
const vectorStoreInfo: VectorStoreInfo = {
name: "state_of_union_address",
description: "the most recent state of the Union address",
vectorStore,
};

const toolkit = new VectorStoreToolkit(vectorStoreInfo, model);
const agent = createVectorStoreAgent(model, toolkit);

const input =
"What did biden say about Ketanji Brown Jackson is the state of the union address?";
console.log(`Executing: ${input}`);
const result = await agent.call({ input });
console.log(`Got output ${result.output}`);
console.log(
`Got intermediate steps ${JSON.stringify(
result.intermediateSteps,
null,
2
)}`
);
};
7 changes: 7 additions & 0 deletions langchain/src/agents/agent_toolkits/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,10 @@ export {
OpenApiToolkit,
createOpenApiAgent,
} from "./openapi/openapi.js";
export {
VectorStoreInfo,
VectorStoreToolkit,
VectorStoreRouterToolkit,
createVectorStoreAgent,
createVectorStoreRouterAgent,
} from "./vectorstore/vectorstore.js";
9 changes: 9 additions & 0 deletions langchain/src/agents/agent_toolkits/vectorstore/prompt.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
export const VECTOR_PREFIX = `You are an agent designed to answer questions about sets of documents.
You have access to tools for interacting with the documents, and the inputs to the tools are questions.
Sometimes, you will be asked to provide sources for your questions, in which case you should use the appropriate tool to do so.
If the question does not seem relevant to any of the tools provided, just return "I don't know" as the answer.`;

export const VECTOR_ROUTER_PREFIX = `You are an agent designed to answer questions.
You have access to tools for interacting with different sources, and the inputs to the tools are questions.
Your main task is to decide which of the tools is relevant for answering question at hand.
For complex questions, you can break the question down into sub questions and use tools to answers the sub questions.`;
116 changes: 116 additions & 0 deletions langchain/src/agents/agent_toolkits/vectorstore/vectorstore.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
import { Tool, VectorStoreQATool } from "../../tools/index.js";
import { VectorStore } from "../../../vectorstores/index.js";
import { Toolkit } from "../base.js";
import { BaseLLM } from "../../../llms/index.js";
import { CreatePromptArgs, ZeroShotAgent } from "../../mrkl/index.js";
import { VECTOR_PREFIX, VECTOR_ROUTER_PREFIX } from "./prompt.js";
import { SUFFIX } from "../../mrkl/prompt.js";
import { LLMChain } from "../../../chains/index.js";
import { AgentExecutor } from "../../executor.js";

export interface VectorStoreInfo {
vectorStore: VectorStore;
name: string;
description: string;
}

export class VectorStoreToolkit extends Toolkit {
tools: Tool[];

llm: BaseLLM;

constructor(vectorStoreInfo: VectorStoreInfo, llm: BaseLLM) {
super();
const description = VectorStoreQATool.getDescription(
vectorStoreInfo.name,
vectorStoreInfo.description
);
this.llm = llm;
this.tools = [
new VectorStoreQATool(vectorStoreInfo.name, description, {
vectorStore: vectorStoreInfo.vectorStore,
llm: this.llm,
}),
];
}
}

export class VectorStoreRouterToolkit extends Toolkit {
tools: Tool[];

vectorStoreInfos: VectorStoreInfo[];

llm: BaseLLM;

constructor(vectorStoreInfos: VectorStoreInfo[], llm: BaseLLM) {
super();
this.llm = llm;
this.vectorStoreInfos = vectorStoreInfos;
this.tools = vectorStoreInfos.map((vectorStoreInfo) => {
const description = VectorStoreQATool.getDescription(
vectorStoreInfo.name,
vectorStoreInfo.description
);
return new VectorStoreQATool(vectorStoreInfo.name, description, {
vectorStore: vectorStoreInfo.vectorStore,
llm: this.llm,
});
});
}
}

export function createVectorStoreAgent(
llm: BaseLLM,
toolkit: VectorStoreToolkit,
args?: CreatePromptArgs
) {
const {
prefix = VECTOR_PREFIX,
suffix = SUFFIX,
inputVariables = ["input", "agent_scratchpad"],
} = args ?? {};
const { tools } = toolkit;
const prompt = ZeroShotAgent.createPrompt(tools, {
prefix,
suffix,
inputVariables,
});
const chain = new LLMChain({ prompt, llm });
const agent = new ZeroShotAgent({
llmChain: chain,
allowedTools: tools.map((t) => t.name),
});
return AgentExecutor.fromAgentAndTools({
agent,
tools,
returnIntermediateSteps: true,
});
}

export function createVectorStoreRouterAgent(
llm: BaseLLM,
toolkit: VectorStoreRouterToolkit,
args?: CreatePromptArgs
) {
const {
prefix = VECTOR_ROUTER_PREFIX,
suffix = SUFFIX,
inputVariables = ["input", "agent_scratchpad"],
} = args ?? {};
const { tools } = toolkit;
const prompt = ZeroShotAgent.createPrompt(tools, {
prefix,
suffix,
inputVariables,
});
const chain = new LLMChain({ prompt, llm });
const agent = new ZeroShotAgent({
llmChain: chain,
allowedTools: tools.map((t) => t.name),
});
return AgentExecutor.fromAgentAndTools({
agent,
tools,
returnIntermediateSteps: true,
});
}
4 changes: 4 additions & 0 deletions langchain/src/agents/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,11 @@ export {
JsonToolkit,
RequestsToolkit,
OpenApiToolkit,
VectorStoreInfo,
VectorStoreToolkit,
VectorStoreRouterToolkit,
createSqlAgent,
createJsonAgent,
createOpenApiAgent,
createVectorStoreAgent,
} from "./agent_toolkits/index.js";
1 change: 1 addition & 0 deletions langchain/src/agents/tools/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,4 @@ export {
Json,
} from "./json.js";
export { RequestsGetTool, RequestsPostTool } from "./requests.js";
export { VectorStoreQATool } from "./vectorstore.js";
38 changes: 38 additions & 0 deletions langchain/src/agents/tools/vectorstore.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import { VectorStore } from "../../vectorstores/index.js";
import { BaseLLM } from "../../llms/index.js";
import { VectorDBQAChain } from "../../chains/index.js";
import { Tool } from "./base.js";

interface VectorStoreTool {
vectorStore: VectorStore;
llm: BaseLLM;
}

export class VectorStoreQATool extends Tool implements VectorStoreTool {
vectorStore: VectorStore;

llm: BaseLLM;

name: string;

description: string;

chain: VectorDBQAChain;

constructor(name: string, description: string, fields: VectorStoreTool) {
super();
this.name = name;
this.description = description;
this.vectorStore = fields.vectorStore;
this.llm = fields.llm;
this.chain = VectorDBQAChain.fromLLM(this.llm, this.vectorStore);
}

static getDescription(name: string, description: string): string {
return `Useful for when you need to answer questions about ${name}. Whenever you need information about ${description} you should ALWAYS use this. Input should be a fully formed question.`;
}

async call(input: string) {
return this.chain.run(input);
}
}

0 comments on commit e0a3347

Please sign in to comment.