Skip to content

Commit

Permalink
♻ Refactor agents and prompts
Browse files Browse the repository at this point in the history
  • Loading branch information
asim-shrestha committed Apr 19, 2023
1 parent 511a008 commit dea173b
Show file tree
Hide file tree
Showing 13 changed files with 149 additions and 177 deletions.
2 changes: 1 addition & 1 deletion __tests__/extract-array.test.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { extractArray } from "../src/utils/chain";
import { extractArray } from "../src/utils/helpers";

describe("Strings should be extracted from arrays correctly", () => {
it("simple", () => {
Expand Down
8 changes: 4 additions & 4 deletions src/components/AutonomousAgent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -141,10 +141,10 @@ class AutonomousAgent {
if (!env.NEXT_PUBLIC_FF_MOCK_MODE_ENABLED) {
await testConnection(this.modelSettings);
}
return await AgentService.startAgent(this.modelSettings, this.goal);
return await AgentService.startGoalAgent(this.modelSettings, this.goal);
}

const res = await axios.post(`/api/chain`, {
const res = await axios.post(`/api/start`, {
modelSettings: this.modelSettings,
goal: this.goal,
});
Expand All @@ -158,7 +158,7 @@ class AutonomousAgent {
result: string
): Promise<string[]> {
if (this.shouldRunClientSide()) {
return await AgentService.createAgent(
return await AgentService.createTasksAgent(
this.modelSettings,
this.goal,
this.tasks,
Expand All @@ -182,7 +182,7 @@ class AutonomousAgent {

async executeTask(task: string): Promise<string> {
if (this.shouldRunClientSide()) {
return await AgentService.executeAgent(
return await AgentService.executeTaskAgent(
this.modelSettings,
this.goal,
task
Expand Down
2 changes: 1 addition & 1 deletion src/components/Input.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import React from "react";
import Label from "./Label";
import clsx from "clsx";
import Combobox from "./Combobox";
import isArrayOfType from "../utils/helpers";
import { isArrayOfType } from "../utils/helpers";
import type { toolTipProperties } from "./types";

interface InputProps {
Expand Down
2 changes: 1 addition & 1 deletion src/pages/api/create.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ const handler = async (request: NextRequest) => {
return;
}

const newTasks = await AgentService.createAgent(
const newTasks = await AgentService.createTasksAgent(
modelSettings,
goal,
tasks,
Expand Down
6 changes: 5 additions & 1 deletion src/pages/api/execute.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,11 @@ const handler = async (request: NextRequest) => {
return;
}

const response = await AgentService.executeAgent(modelSettings, goal, task);
const response = await AgentService.executeTaskAgent(
modelSettings,
goal,
task
);
return NextResponse.json({
response: response,
});
Expand Down
2 changes: 1 addition & 1 deletion src/pages/api/chain.ts → src/pages/api/start.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ export const config = {
const handler = async (request: NextRequest) => {
try {
const { modelSettings, goal } = (await request.json()) as RequestBody;
const newTasks = await AgentService.startAgent(modelSettings, goal);
const newTasks = await AgentService.startGoalAgent(modelSettings, goal);
return NextResponse.json({ newTasks });
} catch (e) {}

Expand Down
2 changes: 0 additions & 2 deletions src/server/api/root.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import { createTRPCRouter } from "./trpc";
import { exampleRouter } from "./routers/example";
import { chainRouter } from "./routers/chain";
import { agentRouter } from "./routers/agentRouter";
import { accountRouter } from "./routers/account";

Expand All @@ -11,7 +10,6 @@ import { accountRouter } from "./routers/account";
*/
export const appRouter = createTRPCRouter({
example: exampleRouter,
chain: chainRouter,
agent: agentRouter,
account: accountRouter,
});
Expand Down
16 changes: 0 additions & 16 deletions src/server/api/routers/chain.ts

This file was deleted.

106 changes: 59 additions & 47 deletions src/services/agent-service.ts
Original file line number Diff line number Diff line change
@@ -1,98 +1,110 @@
import {
createModel,
executeCreateTaskAgent,
executeTaskAgent,
extractArray,
realTasksFilter,
startGoalAgent,
} from "../utils/chain";
startGoalPrompt,
executeTaskPrompt,
createTasksPrompt,
} from "../utils/prompts";
import type { ModelSettings } from "../utils/types";
import { env } from "../env/client.mjs";
import { LLMChain } from "langchain/chains";
import { extractTasks } from "../utils/helpers";

async function startAgent(modelSettings: ModelSettings, goal: string) {
const completion = await startGoalAgent(createModel(modelSettings), goal);
console.log(typeof completion.text);
async function startGoalAgent(modelSettings: ModelSettings, goal: string) {
const completion = await new LLMChain({
llm: createModel(modelSettings),
prompt: startGoalPrompt,
}).call({
goal,
});
console.log("Completion:" + (completion.text as string));
return extractArray(completion.text as string).filter(realTasksFilter);
return extractTasks(completion.text as string, []);
}

async function createAgent(
async function executeTaskAgent(
modelSettings: ModelSettings,
goal: string,
tasks: string[],
lastTask: string,
result: string,
completedTasks: string[] | undefined
task: string
) {
const completion = await executeCreateTaskAgent(
createModel(modelSettings),
const completion = await new LLMChain({
llm: createModel(modelSettings),
prompt: executeTaskPrompt,
}).call({
goal,
tasks,
lastTask,
result
);
task,
});

return extractArray(completion.text as string)
.filter(realTasksFilter)
.filter((task) => !(completedTasks || []).includes(task));
return completion.text as string;
}

async function executeAgent(
async function createTasksAgent(
modelSettings: ModelSettings,
goal: string,
task: string
tasks: string[],
lastTask: string,
result: string,
completedTasks: string[] | undefined
) {
const completion = await executeTaskAgent(
createModel(modelSettings),
const completion = await new LLMChain({
llm: createModel(modelSettings),
prompt: createTasksPrompt,
}).call({
goal,
task
);
return completion.text as string;
tasks,
lastTask,
result,
});

return extractTasks(completion.text as string, completedTasks || []);
}

interface AgentService {
startAgent: (modelSettings: ModelSettings, goal: string) => Promise<string[]>;
createAgent: (
startGoalAgent: (
modelSettings: ModelSettings,
goal: string
) => Promise<string[]>;
executeTaskAgent: (
modelSettings: ModelSettings,
goal: string,
task: string
) => Promise<string>;
createTasksAgent: (
modelSettings: ModelSettings,
goal: string,
tasks: string[],
lastTask: string,
result: string,
completedTasks: string[] | undefined
) => Promise<string[]>;
executeAgent: (
modelSettings: ModelSettings,
goal: string,
task: string
) => Promise<string>;
}

const OpenAIAgentService: AgentService = {
startAgent: startAgent,
createAgent: createAgent,
executeAgent: executeAgent,
startGoalAgent: startGoalAgent,
executeTaskAgent: executeTaskAgent,
createTasksAgent: createTasksAgent,
};

const MockAgentService: AgentService = {
startAgent: async (modelSettings, goal) => {
return ["Task 1"];
startGoalAgent: async (modelSettings, goal) => {
return await new Promise((resolve) => resolve(["Task 1"]));
},
createAgent: async (

createTasksAgent: async (
modelSettings: ModelSettings,
goal: string,
tasks: string[],
lastTask: string,
result: string,
completedTasks: string[] | undefined
) => {
return ["Task 4"];
return await new Promise((resolve) => resolve(["Task 4"]));
},
executeAgent: async (

executeTaskAgent: async (
modelSettings: ModelSettings,
goal: string,
task: string
) => {
return "Result " + task;
return await new Promise((resolve) => resolve("Result: " + task));
},
};

Expand Down
100 changes: 0 additions & 100 deletions src/utils/chain.ts

This file was deleted.

2 changes: 1 addition & 1 deletion src/utils/constants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@ export const GPT_35_TURBO = "gpt-3.5-turbo";
export const GPT_4 = "gpt-4";
export const GPT_MODEL_NAMES = [GPT_35_TURBO, GPT_4];

export const DEFAULT_MAX_LOOPS_FREE = 4;
export const DEFAULT_MAX_LOOPS_FREE = 3;
export const DEFAULT_MAX_LOOPS_PAID = 16;
export const DEFAULT_MAX_LOOPS_CUSTOM_API_KEY = 50;
Loading

0 comments on commit dea173b

Please sign in to comment.