diff --git a/docs/docs/modules/agents/how_to/structured_output_runnables_agent.mdx b/docs/docs/modules/agents/how_to/structured_output_runnables_agent.mdx new file mode 100644 index 000000000000..edbc995eb790 --- /dev/null +++ b/docs/docs/modules/agents/how_to/structured_output_runnables_agent.mdx @@ -0,0 +1,185 @@ +# Structured Output Agent with Runnables + +The `AgentExecutor` class accepts a `Runnable` as the agent. With this, we can create powerful agents using the LCEL and the `AgentExecutor` class. + +Here is a simple example of an agent which uses `Runnables`, a retriever and a structured output parser to create an OpenAI functions agent that finds specific information in a large text document. + +The first step is to import necessary modules + +```typescript +import { zodToJsonSchema } from "zod-to-json-schema"; +import fs from "fs"; +import { z } from "zod"; +import type { + AIMessage, + AgentAction, + AgentFinish, + AgentStep, +} from "langchain/schema/index.js"; +import { RunnableSequence } from "langchain/schema/runnable/base.js"; +import { + ChatPromptTemplate, + MessagesPlaceholder, +} from "langchain/prompts/chat.js"; +import { ChatOpenAI } from "langchain/chat_models/openai.js"; +import { createRetrieverTool } from "langchain/agents/toolkits/index.js"; +import { RecursiveCharacterTextSplitter } from "langchain/text_splitter.js"; +import { HNSWLib } from "langchain/vectorstores/hnswlib.js"; +import { OpenAIEmbeddings } from "langchain/embeddings/openai.js"; +import { formatToOpenAIFunction } from "langchain/tools/convert_to_openai.js"; +import { AgentExecutor } from "langchain/agents/executor.js"; +import { formatForOpenAIFunctions } from "langchain/agents/format_scratchpad.js"; +``` + +Next, we load the text document and embed it using the OpenAI embeddings model. 
+ +```typescript +// Read text file & embed documents +const text = fs.readFileSync("examples/state_of_the_union.txt", "utf8"); +const textSplitter = new RecursiveCharacterTextSplitter({ chunkSize: 1000 }); +let docs = await textSplitter.createDocuments([text]); +// Add fake document source information to the metadata +docs = docs.map((doc, i) => ({ + ...doc, + metadata: { + page_chunk: i, + }, +})); +// Initialize docs & create retriever +const vectorStore = await HNSWLib.fromDocuments(docs, new OpenAIEmbeddings()); +``` + +Since we're going to want to retrieve the embeddings inside the agent, we need to instantiate the vector store as a retriever. +We also need an LLM to perform the calls with. + +```typescript +const retriever = vectorStore.asRetriever(); +const llm = new ChatOpenAI({}); +``` + +In order to use our retriever with the LLM as an OpenAI function, we need to convert the retriever to a tool + +```typescript +const retrieverTool = createRetrieverTool(retriever, { + name: "state-of-union-retriever", + description: + "Query a retriever to get information about state of the union address", +}); +``` + +Now we can define our prompt template. We'll use a simple `ChatPromptTemplate` with placeholders for the user's question, and the agent scratchpad (this will be very helpful in the future). + +```typescript +const prompt = ChatPromptTemplate.fromMessages([ + ["system", "You are a helpful assistant"], + new MessagesPlaceholder("agent_scratchpad"), + ["user", "{input}"], +]); +``` + +After that, we define our structured response schema using zod. This schema defines the structure of the final response from the agent. + +```typescript +const responseSchema = z.object({ + answer: z.string().describe("The final answer to respond to the user"), + sources: z + .array(z.string()) + .describe( + "List of page chunks that contain answer to the question. 
Only include a page chunk if it contains relevant information" + ), +}); +``` + +Once our response schema is defined, we can construct it as an OpenAI function to later be passed to the model. +This is an important step regarding consistency as the model will always respond in this schema when it successfully completes a task + +```typescript +const responseOpenAIFunction = { + name: "response", + description: "Return the response to the user", + parameters: zodToJsonSchema(responseSchema), +}; +``` + +Next, we can construct the custom structured output parser. + +```typescript +const structuredOutputParser = ( + output: AIMessage +): AgentAction | AgentFinish => { + // If no function call is passed, return the output as an instance of `AgentFinish` + if (!("function_call" in output.additional_kwargs)) { + return { returnValues: { output: output.content }, log: output.content }; + } + // Extract the function call name and arguments + const functionCall = output.additional_kwargs.function_call; + const name = functionCall?.name as string; + const inputs = functionCall?.arguments as string; + // Parse the arguments as JSON + const jsonInput = JSON.parse(inputs); + // If the function call name is `response` then we know it's used our final + // response function and can return an instance of `AgentFinish` + if (name === "response") { + return { returnValues: { ...jsonInput }, log: output.content }; + } + // If none of the above are true, the agent is not yet finished and we return + // an instance of `AgentAction` + return { + tool: name, + toolInput: jsonInput, + log: output.content, + }; +}; +``` + +After this, we can bind our two functions to the LLM, and create a runnable sequence which will be used as the agent. + +**Important** - note here we pass in `agent_scratchpad` as an input variable, which formats all the previous steps using the `formatForOpenAIFunctions` function. 
This is very important as it contains all the context history the model needs to perform accurate tasks. Without this, the model would have no context on the previous steps taken. +The `formatForOpenAIFunctions` function returns the steps as an array of `BaseMessage`. This is necessary as the `MessagesPlaceholder` class expects this type as the input. + +```typescript +const llmWithTools = llm.bind({ + functions: [formatToOpenAIFunction(retrieverTool), responseOpenAIFunction], +}); +/** Create the runnable */ +const runnableAgent = RunnableSequence.from([ + { + input: (i: { input: string }) => i.input, + agent_scratchpad: (i: { input: string; steps: Array }) => + formatForOpenAIFunctions(i.steps), + }, + prompt, + llmWithTools, + structuredOutputParser, +]); +``` + +Finally, we can create an instance of `AgentExecutor` and run the agent. + +```typescript +const executor = AgentExecutor.fromAgentAndTools({ + agent: runnableAgent, + tools: [retrieverTool], +}); +/** Call invoke on the agent */ +const res = await executor.invoke({ + input: "what did the president say about kentaji brown jackson", +}); +console.log({ + res, +}); +``` + +The output will look something like this + +```typescript +{ + res: { + answer: 'President mentioned that he nominated Circuit Court of Appeals Judge Ketanji Brown Jackson. He described her as one of our nation’s top legal minds and stated that she will continue Justice Breyer’s legacy of excellence.', + sources: [ + 'And I did that 4 days ago, when I nominated Circuit Court of Appeals Judge Ketanji Brown Jackson. One of our nation’s top legal minds, who will continue Justice Breyer’s legacy of excellence. A former top litigator in private practice. A former federal public defender. And from a family of public school educators and police officers. A consensus builder. Since she’s been nominated, she’s received a broad range of support—from the Fraternal Order of Police to former judges appointed by Democrats and Republicans.' 
+ ] + } +} +``` diff --git a/environment_tests/test-exports-bun/src/entrypoints.js b/environment_tests/test-exports-bun/src/entrypoints.js index 9c36721f05a5..3913afeec7c2 100644 --- a/environment_tests/test-exports-bun/src/entrypoints.js +++ b/environment_tests/test-exports-bun/src/entrypoints.js @@ -2,6 +2,7 @@ export * from "langchain/load"; export * from "langchain/load/serializable"; export * from "langchain/agents"; export * from "langchain/agents/toolkits"; +export * from "langchain/agents/format_scratchpad"; export * from "langchain/base_language"; export * from "langchain/tools"; export * from "langchain/chains"; diff --git a/environment_tests/test-exports-cf/src/entrypoints.js b/environment_tests/test-exports-cf/src/entrypoints.js index 9c36721f05a5..3913afeec7c2 100644 --- a/environment_tests/test-exports-cf/src/entrypoints.js +++ b/environment_tests/test-exports-cf/src/entrypoints.js @@ -2,6 +2,7 @@ export * from "langchain/load"; export * from "langchain/load/serializable"; export * from "langchain/agents"; export * from "langchain/agents/toolkits"; +export * from "langchain/agents/format_scratchpad"; export * from "langchain/base_language"; export * from "langchain/tools"; export * from "langchain/chains"; diff --git a/environment_tests/test-exports-cjs/src/entrypoints.js b/environment_tests/test-exports-cjs/src/entrypoints.js index 41cacf76314b..d185b8576273 100644 --- a/environment_tests/test-exports-cjs/src/entrypoints.js +++ b/environment_tests/test-exports-cjs/src/entrypoints.js @@ -2,6 +2,7 @@ const load = require("langchain/load"); const load_serializable = require("langchain/load/serializable"); const agents = require("langchain/agents"); const agents_toolkits = require("langchain/agents/toolkits"); +const agents_format_scratchpad = require("langchain/agents/format_scratchpad"); const base_language = require("langchain/base_language"); const tools = require("langchain/tools"); const chains = require("langchain/chains"); diff --git 
a/environment_tests/test-exports-esbuild/src/entrypoints.js b/environment_tests/test-exports-esbuild/src/entrypoints.js index cde0f3318c55..ca15ae428d33 100644 --- a/environment_tests/test-exports-esbuild/src/entrypoints.js +++ b/environment_tests/test-exports-esbuild/src/entrypoints.js @@ -2,6 +2,7 @@ import * as load from "langchain/load"; import * as load_serializable from "langchain/load/serializable"; import * as agents from "langchain/agents"; import * as agents_toolkits from "langchain/agents/toolkits"; +import * as agents_format_scratchpad from "langchain/agents/format_scratchpad"; import * as base_language from "langchain/base_language"; import * as tools from "langchain/tools"; import * as chains from "langchain/chains"; diff --git a/environment_tests/test-exports-esm/src/entrypoints.js b/environment_tests/test-exports-esm/src/entrypoints.js index cde0f3318c55..ca15ae428d33 100644 --- a/environment_tests/test-exports-esm/src/entrypoints.js +++ b/environment_tests/test-exports-esm/src/entrypoints.js @@ -2,6 +2,7 @@ import * as load from "langchain/load"; import * as load_serializable from "langchain/load/serializable"; import * as agents from "langchain/agents"; import * as agents_toolkits from "langchain/agents/toolkits"; +import * as agents_format_scratchpad from "langchain/agents/format_scratchpad"; import * as base_language from "langchain/base_language"; import * as tools from "langchain/tools"; import * as chains from "langchain/chains"; diff --git a/environment_tests/test-exports-vercel/src/entrypoints.js b/environment_tests/test-exports-vercel/src/entrypoints.js index 9c36721f05a5..3913afeec7c2 100644 --- a/environment_tests/test-exports-vercel/src/entrypoints.js +++ b/environment_tests/test-exports-vercel/src/entrypoints.js @@ -2,6 +2,7 @@ export * from "langchain/load"; export * from "langchain/load/serializable"; export * from "langchain/agents"; export * from "langchain/agents/toolkits"; +export * from "langchain/agents/format_scratchpad"; 
export * from "langchain/base_language"; export * from "langchain/tools"; export * from "langchain/chains"; diff --git a/environment_tests/test-exports-vite/src/entrypoints.js b/environment_tests/test-exports-vite/src/entrypoints.js index 9c36721f05a5..3913afeec7c2 100644 --- a/environment_tests/test-exports-vite/src/entrypoints.js +++ b/environment_tests/test-exports-vite/src/entrypoints.js @@ -2,6 +2,7 @@ export * from "langchain/load"; export * from "langchain/load/serializable"; export * from "langchain/agents"; export * from "langchain/agents/toolkits"; +export * from "langchain/agents/format_scratchpad"; export * from "langchain/base_language"; export * from "langchain/tools"; export * from "langchain/chains"; diff --git a/langchain/.gitignore b/langchain/.gitignore index 2f99b3435532..dea823522d73 100644 --- a/langchain/.gitignore +++ b/langchain/.gitignore @@ -19,6 +19,9 @@ agents/toolkits/aws_sfn.d.ts agents/toolkits/sql.cjs agents/toolkits/sql.js agents/toolkits/sql.d.ts +agents/format_scratchpad.cjs +agents/format_scratchpad.js +agents/format_scratchpad.d.ts base_language.cjs base_language.js base_language.d.ts diff --git a/langchain/package.json b/langchain/package.json index badbdc329255..9807d3b0cefd 100644 --- a/langchain/package.json +++ b/langchain/package.json @@ -31,6 +31,9 @@ "agents/toolkits/sql.cjs", "agents/toolkits/sql.js", "agents/toolkits/sql.d.ts", + "agents/format_scratchpad.cjs", + "agents/format_scratchpad.js", + "agents/format_scratchpad.d.ts", "base_language.cjs", "base_language.js", "base_language.d.ts", @@ -1337,6 +1340,11 @@ "import": "./agents/toolkits/sql.js", "require": "./agents/toolkits/sql.cjs" }, + "./agents/format_scratchpad": { + "types": "./agents/format_scratchpad.d.ts", + "import": "./agents/format_scratchpad.js", + "require": "./agents/format_scratchpad.cjs" + }, "./base_language": { "types": "./base_language.d.ts", "import": "./base_language.js", diff --git a/langchain/scripts/create-entrypoints.js 
b/langchain/scripts/create-entrypoints.js index 03c6da23ab3a..7319b5f1e588 100644 --- a/langchain/scripts/create-entrypoints.js +++ b/langchain/scripts/create-entrypoints.js @@ -16,6 +16,7 @@ const entrypoints = { "agents/toolkits": "agents/toolkits/index", "agents/toolkits/aws_sfn": "agents/toolkits/aws_sfn", "agents/toolkits/sql": "agents/toolkits/sql/index", + "agents/format_scratchpad": "agents/format_scratchpad", // base language base_language: "base_language/index", // tools diff --git a/langchain/src/agents/agent.ts b/langchain/src/agents/agent.ts index 2a8a30102ccf..f979ff7bf887 100644 --- a/langchain/src/agents/agent.ts +++ b/langchain/src/agents/agent.ts @@ -14,9 +14,11 @@ import { StructuredTool, Tool } from "../tools/base.js"; import { AgentActionOutputParser, AgentInput, + RunnableAgentInput, SerializedAgent, StoppingMethod, } from "./types.js"; +import { Runnable } from "../schema/runnable/base.js"; /** * Record type for arguments passed to output parsers. @@ -122,6 +124,52 @@ export abstract class BaseSingleActionAgent extends BaseAgent { ): Promise; } +/** + * Class representing a single action agent which accepts runnables. + * Extends the BaseSingleActionAgent class and provides methods for + * planning agent actions with runnables. 
+ */ +export class RunnableAgent< + RunInput extends ChainValues & { + agent_scratchpad?: string | BaseMessage[]; + stop?: string[]; + }, + RunOutput extends AgentAction | AgentFinish +> extends BaseSingleActionAgent { + protected lc_runnable = true; + + lc_namespace = ["langchain", "agents", "runnable"]; + + runnable: Runnable; + + stop?: string[]; + + get inputKeys(): string[] { + return []; + } + + constructor(fields: RunnableAgentInput) { + super(); + this.runnable = fields.runnable; + this.stop = fields.stop; + } + + async plan( + steps: AgentStep[], + inputs: RunInput, + callbackManager?: CallbackManager + ): Promise { + const invokeInput = { ...inputs, steps }; + + const output = await this.runnable.invoke(invokeInput, { + callbacks: callbackManager, + runName: "RunnableAgent", + }); + + return output; + } +} + /** * Abstract base class for multi-action agents in LangChain. Extends the * BaseAgent class and provides additional functionality specific to @@ -249,6 +297,7 @@ export abstract class Agent extends BaseSingleActionAgent { constructor(input: AgentInput) { super(input); + this.llmChain = input.llmChain; this._allowedTools = input.allowedTools; this.outputParser = input.outputParser; diff --git a/langchain/src/agents/executor.ts b/langchain/src/agents/executor.ts index b64986dd46d5..62076359e0d2 100644 --- a/langchain/src/agents/executor.ts +++ b/langchain/src/agents/executor.ts @@ -1,25 +1,50 @@ import { BaseChain, ChainInputs } from "../chains/base.js"; -import { BaseMultiActionAgent, BaseSingleActionAgent } from "./agent.js"; +import { + BaseMultiActionAgent, + BaseSingleActionAgent, + RunnableAgent, +} from "./agent.js"; import { StoppingMethod } from "./types.js"; import { SerializedLLMChain } from "../chains/serde.js"; import { AgentAction, AgentFinish, AgentStep, + BaseMessage, ChainValues, } from "../schema/index.js"; import { CallbackManagerForChainRun } from "../callbacks/manager.js"; import { OutputParserException } from 
"../schema/output_parser.js"; -import { Tool, ToolInputParsingException } from "../tools/base.js"; +import { + StructuredTool, + Tool, + ToolInputParsingException, +} from "../tools/base.js"; +import { Runnable } from "../schema/runnable/base.js"; + +type ExtractToolType = T extends { ToolType: infer Tool } + ? Tool + : StructuredTool; /** * Interface defining the structure of input data for creating an * AgentExecutor. It extends ChainInputs and includes additional * properties specific to agent execution. */ -export interface AgentExecutorInput extends ChainInputs { - agent: BaseSingleActionAgent | BaseMultiActionAgent; - tools: this["agent"]["ToolType"][]; +export interface AgentExecutorInput< + RunInput extends ChainValues & { + agent_scratchpad?: string | BaseMessage[]; + stop?: string[]; + // eslint-disable-next-line @typescript-eslint/no-explicit-any + } = any, + // eslint-disable-next-line @typescript-eslint/no-explicit-any + RunOutput extends AgentAction | AgentFinish = any +> extends ChainInputs { + agent: + | BaseSingleActionAgent + | BaseMultiActionAgent + | Runnable; + tools: ExtractToolType[]; returnIntermediateSteps?: boolean; maxIterations?: number; earlyStoppingMethod?: StoppingMethod; @@ -47,7 +72,15 @@ export class ExceptionTool extends Tool { * A chain managing an agent using tools. 
* @augments BaseChain */ -export class AgentExecutor extends BaseChain { +export class AgentExecutor< + RunInput extends ChainValues & { + agent_scratchpad?: string | BaseMessage[]; + stop?: string[]; + // eslint-disable-next-line @typescript-eslint/no-explicit-any + } = any, + // eslint-disable-next-line @typescript-eslint/no-explicit-any + RunOutput extends AgentAction | AgentFinish = any +> extends BaseChain { static lc_name() { return "AgentExecutor"; } @@ -90,9 +123,16 @@ export class AgentExecutor extends BaseChain { return this.agent.returnValues; } - constructor(input: AgentExecutorInput) { + constructor(input: AgentExecutorInput) { + let agent: BaseSingleActionAgent | BaseMultiActionAgent; + if (Runnable.isRunnable(input.agent)) { + agent = new RunnableAgent({ runnable: input.agent }); + } else { + agent = input.agent; + } + super(input); - this.agent = input.agent; + this.agent = agent; this.tools = input.tools; this.handleParsingErrors = input.handleParsingErrors ?? this.handleParsingErrors; @@ -113,7 +153,15 @@ export class AgentExecutor extends BaseChain { } /** Create from agent and a list of tools. 
*/ - static fromAgentAndTools(fields: AgentExecutorInput): AgentExecutor { + static fromAgentAndTools< + RunInput extends ChainValues & { + agent_scratchpad?: string | BaseMessage[]; + stop?: string[]; + }, + RunOutput extends AgentAction | AgentFinish + >( + fields: AgentExecutorInput + ): AgentExecutor { return new AgentExecutor(fields); } diff --git a/langchain/src/agents/format_scratchpad.ts b/langchain/src/agents/format_scratchpad.ts new file mode 100644 index 000000000000..8a4917d80ec1 --- /dev/null +++ b/langchain/src/agents/format_scratchpad.ts @@ -0,0 +1,31 @@ +import { renderTemplate } from "../prompts/template.js"; +import { + AgentStep, + BaseMessage, + AIMessage, + HumanMessage, +} from "../schema/index.js"; +import { TEMPLATE_TOOL_RESPONSE } from "./chat_convo/prompt.js"; + +/** + * Format a list of AgentSteps into a list of BaseMessage instances for + * agents that use OpenAI's API. Helpful for passing in previous agent + * step context into new iterations. + * + * @param steps A list of AgentSteps to format. + * @returns A list of BaseMessages. 
+ */ +export function formatForOpenAIFunctions(steps: AgentStep[]): BaseMessage[] { + const thoughts: BaseMessage[] = []; + for (const step of steps) { + thoughts.push(new AIMessage(step.action.log)); + thoughts.push( + new HumanMessage( + renderTemplate(TEMPLATE_TOOL_RESPONSE, "f-string", { + observation: step.observation, + }) + ) + ); + } + return thoughts; +} diff --git a/langchain/src/agents/tests/agent.int.test.ts b/langchain/src/agents/tests/agent.int.test.ts index 72ab83e3dcdb..1f0b3defd960 100644 --- a/langchain/src/agents/tests/agent.int.test.ts +++ b/langchain/src/agents/tests/agent.int.test.ts @@ -3,12 +3,16 @@ import { expect, test } from "@jest/globals"; import { OpenAI } from "../../llms/openai.js"; import { OpenAIEmbeddings } from "../../embeddings/openai.js"; import { loadAgent } from "../load.js"; -import { AgentExecutor } from "../index.js"; +import { AgentExecutor, ZeroShotAgent } from "../index.js"; import { SerpAPI } from "../../tools/serpapi.js"; import { Calculator } from "../../tools/calculator.js"; import { initializeAgentExecutorWithOptions } from "../initialize.js"; import { WebBrowser } from "../../tools/webbrowser.js"; import { Tool } from "../../tools/base.js"; +import { ChatOpenAI } from "../../chat_models/openai.js"; +import { RunnableSequence } from "../../schema/runnable/base.js"; +import { OutputParserException } from "../../schema/output_parser.js"; +import { AIMessage } from "../../schema/index.js"; test("Run agent from hub", async () => { const model = new OpenAI({ temperature: 0, modelName: "text-babbage-001" }); @@ -33,7 +37,167 @@ test("Run agent from hub", async () => { input: "Who is Olivia Wilde's boyfriend? 
What is his current age raised to the 0.23 power?", }); - console.log(res); + console.log( + { + res, + }, + "Run agent from hub response" + ); + expect(res.output).not.toEqual(""); + expect(res.output).not.toEqual("Agent stopped due to max iterations."); +}); + +test("Pass runnable to agent executor", async () => { + const model = new ChatOpenAI({ temperature: 0, modelName: "gpt-3.5-turbo" }); + const tools: Tool[] = [ + new SerpAPI(undefined, { + location: "Austin,Texas,United States", + hl: "en", + gl: "us", + }), + new Calculator(), + ]; + + const prompt = ZeroShotAgent.createPrompt(tools); + const outputParser = ZeroShotAgent.getDefaultOutputParser(); + + const runnable = RunnableSequence.from([ + { + input: (i: { input: string }) => i.input, + agent_scratchpad: (i: { input: string }) => i.input, + }, + prompt, + model, + outputParser, + ]); + + const executor = AgentExecutor.fromAgentAndTools({ + agent: runnable, + tools, + }); + const res = await executor.invoke({ + input: + "Who is Olivia Wilde's boyfriend? What is his current age raised to the 0.23 power?", + }); + console.log( + { + res, + }, + "Pass runnable to agent executor" + ); + expect(res.output).not.toEqual(""); + expect(res.output).not.toEqual("Agent stopped due to max iterations."); +}); + +test("Custom output parser", async () => { + const model = new ChatOpenAI({ temperature: 0, modelName: "gpt-3.5-turbo" }); + const tools: Tool[] = [ + new SerpAPI(undefined, { + location: "Austin,Texas,United States", + hl: "en", + gl: "us", + }), + new Calculator(), + ]; + + const parser = (output: AIMessage) => { + const text = output.content; + if (text.includes("Final Answer:")) { + return { + returnValues: { + output: "We did it!", + }, + log: text, + }; + } + + const match = /Action:([\s\S]*?)(?:\nAction Input:([\s\S]*?))?$/.exec(text); + if (!match) { + throw new OutputParserException(`Could not parse LLM output: ${text}`); + } + + return { + tool: match[1].trim(), + toolInput: match[2] + ? 
match[2].trim().replace(/^("+)(.*?)(\1)$/, "$2") + : "", + log: text, + }; + }; + + const prompt = ZeroShotAgent.createPrompt(tools); + + const runnable = RunnableSequence.from([ + { + input: (i: { input: string }) => i.input, + agent_scratchpad: (i: { input: string }) => i.input, + }, + prompt, + model, + parser, + ]); + + const executor = AgentExecutor.fromAgentAndTools({ + agent: runnable, + tools, + }); + const res = await executor.invoke({ + input: + "Who is Olivia Wilde's boyfriend? What is his current age raised to the 0.23 power?", + }); + console.log( + { + res, + }, + "Custom output parser" + ); + expect(res.output).toEqual("We did it!"); +}); + +test("Add a fallback method", async () => { + // Model should always fail since the model name passed does not exist. + const modelBase = new ChatOpenAI({ + modelName: "fake-model", + temperature: 10, + }); + + const modelLarge = new ChatOpenAI({ + modelName: "gpt-3.5-turbo-16k", + temperature: 0.6, + }); + + const model = modelBase.withFallbacks({ + fallbacks: [modelLarge], + }); + + const prompt = ZeroShotAgent.createPrompt([]); + const outputParser = ZeroShotAgent.getDefaultOutputParser(); + + const runnable = RunnableSequence.from([ + { + input: (i: { input: string }) => i.input, + agent_scratchpad: (i: { input: string }) => i.input, + }, + prompt, + model, + outputParser, + ]); + + const executor = AgentExecutor.fromAgentAndTools({ + agent: runnable, + tools: [], + }); + const res = await executor.invoke({ + input: "Is the sky blue? 
Response with a concise answer", + }); + console.log( + { + res, + }, + "Pass runnable to agent executor" + ); + expect(res.output).not.toEqual(""); + expect(res.output).not.toEqual("Agent stopped due to max iterations."); }); test("Run agent locally", async () => { @@ -56,8 +220,14 @@ test("Run agent locally", async () => { console.log(`Executing with input "${input}"...`); const result = await executor.call({ input }); - - console.log(`Got output ${result.output}`); + console.log( + { + result, + }, + "Run agent locally" + ); + expect(result.output).not.toEqual(""); + expect(result.output).not.toEqual("Agent stopped due to max iterations."); }); test("Run agent with an abort signal", async () => { @@ -142,10 +312,14 @@ test("Run tool web-browser", async () => { console.log(`Executing with input "${input}"...`); const result = await executor.call({ input }); - - console.log(`Got output ${result.output}`); - + console.log( + { + result, + }, + "Run tool web-browser" + ); expect(result.intermediateSteps.length).toBeGreaterThanOrEqual(1); expect(result.intermediateSteps[0].action.tool).toEqual("web-browser"); expect(result.output).not.toEqual(""); + expect(result.output).not.toEqual("Agent stopped due to max iterations."); }); diff --git a/langchain/src/agents/tests/structured_output_runnables.int.test.ts b/langchain/src/agents/tests/structured_output_runnables.int.test.ts new file mode 100644 index 000000000000..2e29a8a57675 --- /dev/null +++ b/langchain/src/agents/tests/structured_output_runnables.int.test.ts @@ -0,0 +1,127 @@ +import { zodToJsonSchema } from "zod-to-json-schema"; +import fs from "fs"; +import { z } from "zod"; +import { + AIMessage, + AgentAction, + AgentFinish, + AgentStep, +} from "../../schema/index.js"; +import { RunnableSequence } from "../../schema/runnable/base.js"; +import { ChatPromptTemplate, MessagesPlaceholder } from "../../prompts/chat.js"; +import { ChatOpenAI } from "../../chat_models/openai.js"; +import { createRetrieverTool } from 
"../toolkits/index.js"; +import { RecursiveCharacterTextSplitter } from "../../text_splitter.js"; +import { HNSWLib } from "../../vectorstores/hnswlib.js"; +import { OpenAIEmbeddings } from "../../embeddings/openai.js"; +import { formatToOpenAIFunction } from "../../tools/convert_to_openai.js"; +import { AgentExecutor } from "../executor.js"; +import { formatForOpenAIFunctions } from "../format_scratchpad.js"; + +/** Define a custom structured output parser. */ +const structuredOutputParser = ( + output: AIMessage +): AgentAction | AgentFinish => { + if (!("function_call" in output.additional_kwargs)) { + return { returnValues: { output: output.content }, log: output.content }; + } + + const functionCall = output.additional_kwargs.function_call; + const name = functionCall?.name as string; + const inputs = functionCall?.arguments as string; + + const jsonInput = JSON.parse(inputs); + + if (name === "response") { + return { returnValues: { ...jsonInput }, log: output.content }; + } + + return { + tool: name, + toolInput: jsonInput, + log: output.content, + }; +}; + +test("Pass custom structured output parsers", async () => { + /** Read text file & embed documents */ + const text = fs.readFileSync("../examples/state_of_the_union.txt", "utf8"); + const textSplitter = new RecursiveCharacterTextSplitter({ chunkSize: 1000 }); + let docs = await textSplitter.createDocuments([text]); + // Add fake source information + docs = docs.map((doc, i) => ({ + ...doc, + metadata: { + page_chunk: i, + }, + })); + /** Initialize docs & create retriever */ + const vectorStore = await HNSWLib.fromDocuments(docs, new OpenAIEmbeddings()); + const retriever = vectorStore.asRetriever(); + /** Instantiate the LLM */ + const llm = new ChatOpenAI({}); + /** Define the prompt template */ + const prompt = ChatPromptTemplate.fromMessages([ + ["system", "You are a helpful assistant"], + new MessagesPlaceholder("agent_scratchpad"), + ["user", "{input}"], + ]); + /** Define the response schema */ + 
const responseSchema = z.object({ + answer: z.string().describe("The final answer to respond to the user"), + sources: z + .array(z.string()) + .describe( + "List of page chunks that contain answer to the question. Only include a page chunk if it contains relevant information" + ), + }); + /** Create the response function */ + const responseOpenAIFunction = { + name: "response", + description: "Return the response to the user", + parameters: zodToJsonSchema(responseSchema), + }; + /** Convert retriever into a tool */ + const retrieverTool = createRetrieverTool(retriever, { + name: "state-of-union-retriever", + description: + "Query a retriever to get information about state of the union address", + }); + /** Bind both retriever and response functions to LLM */ + const llmWithTools = llm.bind({ + functions: [formatToOpenAIFunction(retrieverTool), responseOpenAIFunction], + }); + /** Create the runnable */ + const runnableAgent = RunnableSequence.from([ + { + input: (i: { input: string }) => i.input, + agent_scratchpad: (i: { input: string; steps: Array }) => + formatForOpenAIFunctions(i.steps), + }, + prompt, + llmWithTools, + structuredOutputParser, + ]); + /** Create the agent by passing in the runnable & tools */ + const executor = AgentExecutor.fromAgentAndTools({ + agent: runnableAgent, + tools: [retrieverTool], + }); + /** Call invoke on the agent */ + const res = await executor.invoke({ + input: "what did the president say about kentaji brown jackson", + }); + console.log({ + res, + }); + /** + { + res: { + answer: 'President mentioned that he nominated Circuit Court of Appeals Judge Ketanji Brown Jackson. He described her as one of our nation’s top legal minds and stated that she will continue Justice Breyer’s legacy of excellence.', + sources: [ + 'And I did that 4 days ago, when I nominated Circuit Court of Appeals Judge Ketanji Brown Jackson. One of our nation’s top legal minds, who will continue Justice Breyer’s legacy of excellence. 
A former top litigator in private practice. A former federal public defender. And from a family of public school educators and police officers. A consensus builder. Since she’s been nominated, she’s received a broad range of support—from the Fraternal Order of Police to former judges appointed by Democrats and Republicans.' + ] + } + } + */ +}); diff --git a/langchain/src/agents/types.ts b/langchain/src/agents/types.ts index 53f9241384de..d784724a4be4 100644 --- a/langchain/src/agents/types.ts +++ b/langchain/src/agents/types.ts @@ -1,7 +1,13 @@ import { LLMChain } from "../chains/llm_chain.js"; import { SerializedLLMChain } from "../chains/serde.js"; -import { AgentAction, AgentFinish } from "../schema/index.js"; +import { + AgentAction, + AgentFinish, + BaseMessage, + ChainValues, +} from "../schema/index.js"; import { BaseOutputParser } from "../schema/output_parser.js"; +import { Runnable } from "../schema/runnable/base.js"; /** * Interface defining the input for creating an agent. It includes the @@ -14,6 +20,23 @@ export interface AgentInput { allowedTools?: string[]; } +/** + * Interface defining the input for creating an agent that uses runnables. + * It includes the Runnable instance, and an optional list of stop strings. + */ +export interface RunnableAgentInput< + RunInput extends ChainValues & { + agent_scratchpad?: string | BaseMessage[]; + stop?: string[]; + // eslint-disable-next-line @typescript-eslint/no-explicit-any + } = any, + // eslint-disable-next-line @typescript-eslint/no-explicit-any + RunOutput extends AgentAction | AgentFinish = any +> { + runnable: Runnable; + stop?: string[]; +} + /** * Abstract class representing an output parser specifically for agent * actions and finishes in LangChain. 
It extends the `BaseOutputParser` diff --git a/langchain/src/load/import_map.ts b/langchain/src/load/import_map.ts index a968697794d1..2ef3c03bbc15 100644 --- a/langchain/src/load/import_map.ts +++ b/langchain/src/load/import_map.ts @@ -3,6 +3,7 @@ export * as load__serializable from "../load/serializable.js"; export * as agents from "../agents/index.js"; export * as agents__toolkits from "../agents/toolkits/index.js"; +export * as agents__format_scratchpad from "../agents/format_scratchpad.js"; export * as base_language from "../base_language/index.js"; export * as tools from "../tools/index.js"; export * as chains from "../chains/index.js"; diff --git a/langchain/src/schema/runnable/base.ts b/langchain/src/schema/runnable/base.ts index 483d6680786b..3104e215abb0 100644 --- a/langchain/src/schema/runnable/base.ts +++ b/langchain/src/schema/runnable/base.ts @@ -547,7 +547,7 @@ export abstract class Runnable< // eslint-disable-next-line @typescript-eslint/no-explicit-any static isRunnable(thing: any): thing is Runnable { - return thing.lc_runnable; + return thing ? thing.lc_runnable : false; } } diff --git a/langchain/tsconfig.json b/langchain/tsconfig.json index 554d4d972eb8..305bbbd34a79 100644 --- a/langchain/tsconfig.json +++ b/langchain/tsconfig.json @@ -39,6 +39,7 @@ "src/agents/toolkits/index.ts", "src/agents/toolkits/aws_sfn.ts", "src/agents/toolkits/sql/index.ts", + "src/agents/format_scratchpad.ts", "src/base_language/index.ts", "src/tools/index.ts", "src/tools/aws_lambda.ts",