Skip to content

Commit

Permalink
Multi-Action agent (langchain-ai#690)
Browse files Browse the repository at this point in the history
* refactor

* format
  • Loading branch information
agola11 authored Apr 10, 2023
1 parent d37366b commit 17a485e
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 32 deletions.
49 changes: 33 additions & 16 deletions langchain/src/agents/agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class ParseError extends Error {
}
}

export abstract class BaseSingleActionAgent {
export abstract class BaseAgent {
abstract get inputKeys(): string[];

get returnValues(): string[] {
Expand All @@ -43,19 +43,6 @@ export abstract class BaseSingleActionAgent {
throw new Error("Not implemented");
}

/**
* Decide what to do given some input.
*
* @param steps - Steps the LLM has taken so far, along with observations from each.
* @param inputs - User inputs.
*
* @returns Action specifying what tool to use.
*/
abstract plan(
steps: AgentStep[],
inputs: ChainValues
): Promise<AgentAction | AgentFinish>;

/**
* Return response when agent has been stopped due to max iterations
*/
Expand Down Expand Up @@ -85,6 +72,36 @@ export abstract class BaseSingleActionAgent {
}
}

export abstract class BaseSingleActionAgent extends BaseAgent {
/**
* Decide what to do, given some input.
*
* @param steps - Steps the LLM has taken so far, along with observations from each.
* @param inputs - User inputs.
*
* @returns Action specifying what tool to use.
*/
abstract plan(
steps: AgentStep[],
inputs: ChainValues
): Promise<AgentAction | AgentFinish>;
}

export abstract class BaseMultiActionAgent extends BaseAgent {
/**
* Decide what to do, given some input.
*
* @param steps - Steps the LLM has taken so far, along with observations from each.
* @param inputs - User inputs.
*
* @returns Actions specifying what tools to use.
*/
abstract plan(
steps: AgentStep[],
inputs: ChainValues
): Promise<AgentAction[] | AgentFinish>;
}

export interface LLMSingleActionAgentInput {
llmChain: LLMChain;
outputParser: AgentActionOutputParser;
Expand Down Expand Up @@ -183,8 +200,8 @@ export abstract class Agent extends BaseSingleActionAgent {
/**
* Create a prompt for this class
*
* @param tools - List of tools the agent will have access to, used to format the prompt.
* @param fields - Additional fields used to format the prompt.
* @param _tools - List of tools the agent will have access to, used to format the prompt.
* @param _fields - Additional fields used to format the prompt.
*
* @returns A PromptTemplate assembled from the given tools and fields.
* */
Expand Down
68 changes: 52 additions & 16 deletions langchain/src/agents/executor.ts
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
import { BaseChain, ChainInputs } from "../chains/index.js";
import { BaseSingleActionAgent } from "./agent.js";
import { BaseMultiActionAgent, BaseSingleActionAgent } from "./agent.js";
import { Tool } from "./tools/base.js";
import { StoppingMethod } from "./types.js";
import { SerializedLLMChain } from "../chains/serde.js";
import { AgentFinish, AgentStep, ChainValues } from "../schema/index.js";
import {
AgentAction,
AgentFinish,
AgentStep,
ChainValues,
} from "../schema/index.js";

interface AgentExecutorInput extends ChainInputs {
agent: BaseSingleActionAgent;
agent: BaseSingleActionAgent | BaseMultiActionAgent;
tools: Tool[];
returnIntermediateSteps?: boolean;
maxIterations?: number;
Expand All @@ -18,7 +23,7 @@ interface AgentExecutorInput extends ChainInputs {
* @augments BaseChain
*/
export class AgentExecutor extends BaseChain {
agent: BaseSingleActionAgent;
agent: BaseSingleActionAgent | BaseMultiActionAgent;

tools: Tool[];

Expand All @@ -36,6 +41,16 @@ export class AgentExecutor extends BaseChain {
super(input.memory, input.verbose, input.callbackManager);
this.agent = input.agent;
this.tools = input.tools;
// eslint-disable-next-line no-instanceof/no-instanceof
if (this.agent instanceof BaseMultiActionAgent) {
for (const tool of this.tools) {
if (tool.returnDirect) {
throw new Error(
`Tool with return direct ${tool.name} not supported for multi-action agent.`
);
}
}
}
this.returnIntermediateSteps =
input.returnIntermediateSteps ?? this.returnIntermediateSteps;
this.maxIterations = input.maxIterations ?? this.maxIterations;
Expand Down Expand Up @@ -71,23 +86,44 @@ export class AgentExecutor extends BaseChain {
};

while (this.shouldContinue(iterations)) {
const action = await this.agent.plan(steps, inputs);
if ("returnValues" in action) {
return getOutput(action);
const output = await this.agent.plan(steps, inputs);
// Check if the agent has finished
if ("returnValues" in output) {
return getOutput(output);
}

let actions: AgentAction[];
if (Array.isArray(output)) {
actions = output as AgentAction[];
} else {
actions = [output as AgentAction];
}
await this.callbackManager.handleAgentAction(action, this.verbose);

const tool = toolsByName[action.tool?.toLowerCase()];
const observation = tool
? await tool.call(action.toolInput, this.verbose)
: `${action.tool} is not a valid tool, try another one.`;
steps.push({ action, observation });
if (tool?.returnDirect) {

const newSteps = await Promise.all(
actions.map(async (action) => {
await this.callbackManager.handleAgentAction(action, this.verbose);

const tool = toolsByName[action.tool?.toLowerCase()];
const observation = tool
? await tool.call(action.toolInput, this.verbose)
: `${action.tool} is not a valid tool, try another one.`;

return { action, observation };
})
);

steps.push(...newSteps);

const lastStep = steps[steps.length - 1];
const lastTool = toolsByName[lastStep.action.tool?.toLowerCase()];

if (lastTool?.returnDirect) {
return getOutput({
returnValues: { [this.agent.returnValues[0]]: observation },
returnValues: { [this.agent.returnValues[0]]: lastStep.observation },
log: "",
});
}

iterations += 1;
}

Expand Down

0 comments on commit 17a485e

Please sign in to comment.