Skip to content

Commit

Permalink
⚡️ Chat with Agent (reworkd#977)
Browse files Browse the repository at this point in the history
  • Loading branch information
asim-shrestha authored Jul 6, 2023
1 parent 0643242 commit 41806d4
Show file tree
Hide file tree
Showing 17 changed files with 245 additions and 33 deletions.
2 changes: 1 addition & 1 deletion next/src/components/Input.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ const Input = (props: InputProps) => {
inputElement = (
<input
className={clsx(
"background-color-5 placeholder:text-color-tertiary text-color-primary border-color-1 border-focusVisible-1 border-hover-1 w-full rounded-xl border-2 py-1 text-sm tracking-wider outline-0 transition-all duration-200 sm:py-3 md:text-lg",
"background-color-5 placeholder:text-color-tertiary text-color-primary border-color-1 border-focusVisible-1 border-hover-1 w-full rounded-xl border-2 p-2 py-1 text-sm tracking-wider outline-0 transition-all duration-200 sm:py-3 md:text-lg",
disabled && "cursor-not-allowed",
left && "md:rounded-l-none",
small && "text-sm sm:py-[0]"
Expand Down
46 changes: 30 additions & 16 deletions next/src/components/console/ChatWindow.tsx
Original file line number Diff line number Diff line change
@@ -1,29 +1,36 @@
import type { ReactNode } from "react";
import React, { useEffect, useRef, useState } from "react";
import { useTranslation } from "next-i18next";
import FadeIn from "../motions/FadeIn";
import HideShow from "../motions/HideShow";
import clsx from "clsx";
import { ChatMessage } from "./ChatMessage";
import type { HeaderProps } from "./MacWindowHeader";
import { MacWindowHeader, messageListId } from "./MacWindowHeader";
import { FaArrowCircleDown } from "react-icons/fa";
import { FaArrowCircleDown, FaCommentDots } from "react-icons/fa";
import { useAgentStore } from "../../stores";
import { getTaskStatus, TASK_STATUS_EXECUTING } from "../../types/task";
import { ImSpinner2 } from "react-icons/im";
import Input from "../Input";
import Button from "../Button";

interface ChatControls {
value: string;
onChange: (string) => void;
handleChat: () => Promise<void>;
loading?: boolean;
}

interface ChatWindowProps extends HeaderProps {
children?: ReactNode;
setAgentRun?: (name: string, goal: string) => void;
visibleOnMobile?: boolean;
chatControls?: ChatControls;
}

const ChatWindow = ({
messages,
children,
title,
setAgentRun,
visibleOnMobile,
chatControls,
}: ChatWindowProps) => {
const [t] = useTranslation();
const [hasUserScrolled, setHasUserScrolled] = useState(false);
Expand Down Expand Up @@ -63,7 +70,7 @@ const ChatWindow = ({
>
<HideShow
showComponent={hasUserScrolled}
className="absolute bottom-2 right-6 cursor-pointer"
className="absolute bottom-14 right-6 cursor-pointer"
>
<FaArrowCircleDown
onClick={() => handleScrollToBottom("smooth")}
Expand All @@ -78,16 +85,6 @@ const ChatWindow = ({
onScroll={handleScroll}
id={messageListId}
>
{messages.map((message, index) => {
if (getTaskStatus(message) === TASK_STATUS_EXECUTING) {
return null;
}
return (
<FadeIn key={`${index}-${message.type}`}>
<ChatMessage message={message} />
</FadeIn>
);
})}
{children}
<div
className={clsx(
Expand All @@ -100,6 +97,23 @@ const ChatWindow = ({
<ImSpinner2 className="animate-spin" />
</div>
</div>
{chatControls && (
<div className="mt-auto flex flex-row gap-2 p-2 sm:p-4">
<Input
small
placeholder="Chat with your agent..."
value={chatControls.value}
onChange={(e) => chatControls?.onChange(e.target.value)}
/>
<Button
className="px-1 py-1 sm:px-3 md:py-1"
onClick={chatControls?.handleChat}
disabled={chatControls.loading}
>
<FaCommentDots />
</Button>
</div>
)}
</div>
);
};
Expand Down
10 changes: 3 additions & 7 deletions next/src/components/console/SummarizeButton.tsx
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { useAgentStore } from "../../stores";
import { useTaskStore } from "../../stores/taskStore";
import React, { useEffect, useState } from "react";
import React from "react";
import clsx from "clsx";
import Button from "../Button";

Expand All @@ -10,12 +10,8 @@ const Summarize = () => {
const tasksWithResults = useTaskStore.use
.tasks()
.filter((task) => task.status == "completed" && task.result !== "");
const [summarized, setSummarized] = useState(false);

// Reset the summarized state when the agent changes
useEffect(() => {
setSummarized(false);
}, [agent]);
const summarized = useAgentStore.use.summarized();
const setSummarized = useAgentStore.use.setSummarized();

if (!agent || lifecycle !== "stopped" || tasksWithResults.length < 1 || summarized) return null;

Expand Down
36 changes: 32 additions & 4 deletions next/src/pages/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,12 @@ import HelpDialog from "../components/dialog/HelpDialog";
import { useAuth } from "../hooks/useAuth";
import { useAgent } from "../hooks/useAgent";
import { isEmptyOrBlank } from "../utils/whitespace";
import { resetAllMessageSlices, useAgentStore, useMessageStore } from "../stores";
import {
resetAllAgentSlices,
resetAllMessageSlices,
useAgentStore,
useMessageStore,
} from "../stores";
import { serverSideTranslations } from "next-i18next/serverSideTranslations";
import { languages } from "../utils/languages";
import nextI18NextConfig from "../../next-i18next.config.js";
Expand All @@ -28,18 +33,20 @@ import { useRouter } from "next/router";
import { useAgentInputStore } from "../stores/agentInputStore";
import { MessageService } from "../services/agent/message-service";
import { DefaultAgentRunModel } from "../services/agent/agent-run-model";
import { resetAllTaskSlices } from "../stores/taskStore";
import { resetAllTaskSlices, useTaskStore } from "../stores/taskStore";
import { ChatWindowTitle } from "../components/console/ChatWindowTitle";
import { AgentApi } from "../services/agent/agent-api";
import { toApiModelSettings } from "../utils/interfaces";
import ExampleAgents from "../components/console/ExampleAgents";
import Summarize from "../components/console/SummarizeButton";
import AgentControls from "../components/console/AgentControls";
import { ChatMessage } from "../components/console/ChatMessage";

const Home: NextPage = () => {
const { t } = useTranslation("indexPage");
const addMessage = useMessageStore.use.addMessage();
const messages = useMessageStore.use.messages();
const tasks = useTaskStore.use.tasks();
const { query } = useRouter();

const setAgent = useAgentStore.use.setAgent();
Expand All @@ -53,6 +60,7 @@ const Home: NextPage = () => {
const setNameInput = useAgentInputStore.use.setNameInput();
const goalInput = useAgentInputStore.use.goalInput();
const setGoalInput = useAgentInputStore.use.setGoalInput();
const [chatInput, setChatInput] = React.useState("");
const [mobileVisibleWindow, setMobileVisibleWindow] = React.useState<"Chat" | "Tasks">("Chat");
const { settings } = useSettings();

Expand Down Expand Up @@ -116,7 +124,7 @@ const Home: NextPage = () => {
const handleRestart = () => {
resetAllMessageSlices();
resetAllTaskSlices();
setAgent(null);
resetAllAgentSlices();
};

const handleKeyPress = (
Expand Down Expand Up @@ -186,10 +194,30 @@ const Home: NextPage = () => {
<ChatWindow
messages={messages}
title={<ChatWindowTitle model={settings.customModelName} />}
setAgentRun={setAgentRun}
visibleOnMobile={mobileVisibleWindow === "Chat"}
chatControls={
agent
? {
value: chatInput,
onChange: (value: string) => {
setChatInput(value);
},
handleChat: async () => {
await agent?.chat(chatInput);
},
loading: tasks.length == 0,
}
: undefined
}
>
{messages.length === 0 && <ExampleAgents setAgentRun={setAgentRun} />}
{messages.map((message, index) => {
return (
<FadeIn key={`${index}-${message.type}`}>
<ChatMessage message={message} />
</FadeIn>
);
})}
<Summarize />
</ChatWindow>
<TaskWindow visibleOnMobile={mobileVisibleWindow === "Tasks"} />
Expand Down
56 changes: 56 additions & 0 deletions next/src/services/agent/agent-work/chat-work.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import type { Message } from "../../../types/message";
import { v1 } from "uuid";
import { streamText } from "../../stream-utils";
import { toApiModelSettings } from "../../../utils/interfaces";
import type AgentWork from "./agent-work";
import type AutonomousAgent from "../autonomous-agent";

export default class ChatWork implements AgentWork {
constructor(private parent: AutonomousAgent, private message: string) {}

run = async () => {
const executionMessage: Message = {
type: "task",
status: "completed",
value: `Response for '${this.message}'`,
id: v1(),
info: "Loading...",
};
this.parent.messageService.sendMessage({ ...executionMessage });

// TODO: this should be moved to the api layer
await streamText(
"/api/agent/chat",
{
run_id: this.parent.api.runId,
goal: this.parent.model.getGoal(),
model_settings: toApiModelSettings(this.parent.modelSettings, this.parent.session),
message: this.message,
results: this.parent.model
.getCompletedTasks()
.filter((task) => task.result && task.result !== "")
.map((task) => task.result || ""),
},
this.parent.api.props.session?.accessToken || "",
() => {
executionMessage.info = "";
},
(text) => {
executionMessage.info += text;
this.parent.messageService.updateMessage(executionMessage);
},
() => this.parent.model.getLifecycle() === "stopped"
);
this.parent.api.saveMessages([executionMessage]);
};

// eslint-disable-next-line @typescript-eslint/require-await
conclude = async () => void 0;

next = () => undefined;

onError = (e: unknown): boolean => {
this.parent.messageService.sendErrorMessage(e);
return true;
};
}
14 changes: 11 additions & 3 deletions next/src/services/agent/autonomous-agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import type AgentWork from "./agent-work/agent-work";
import { withRetries } from "../api-utils";
import type { Message } from "../../types/message";
import SummarizeWork from "./agent-work/summarize-work";
import ChatWork from "./agent-work/chat-work";

class AutonomousAgent {
model: AgentRunModel;
Expand Down Expand Up @@ -54,7 +55,7 @@ class AutonomousAgent {

// Get and run the next work item
const work = this.workLog[0];
await this.runWork(work);
await this.runWork(work, () => this.model.getLifecycle() === "stopped");

this.workLog.shift();
if (this.model.getLifecycle() !== "running") {
Expand Down Expand Up @@ -83,12 +84,12 @@ class AutonomousAgent {
/*
* Runs a provided work object with error handling and retries
*/
private async runWork(work: AgentWork) {
private async runWork(work: AgentWork, shouldStop: () => boolean = () => false) {
const RETRY_TIMEOUT = 2000;

await withRetries(
async () => {
if (this.model.getLifecycle() === "stopped") return;
if (shouldStop()) return;
await work.run();
},
async (e) => {
Expand Down Expand Up @@ -138,6 +139,13 @@ class AutonomousAgent {
this.model.setLifecycle("stopped");
}

async chat(message: string) {
if (this.model.getLifecycle() == "running") this.pauseAgent();
const chatWork = new ChatWork(this, message);
await this.runWork(chatWork);
await chatWork.conclude();
}

async createTaskMessages(tasks: string[]) {
const TIMOUT_SHORT = 150;
const messages: Message[] = [];
Expand Down
1 change: 1 addition & 0 deletions next/src/services/agent/message-service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ export class MessageService {
sendErrorMessage = (e: unknown) => {
let message = "An unknown error occurred. Please try again later.";
if (typeof e == "string") message = e;
else if (e instanceof Error) message = e.message;
else if (axios.isAxiosError(e) && e.message == "Network Error") {
message = "Error attempting to connect to the server.";
} else if (axios.isAxiosError(e)) {
Expand Down
10 changes: 10 additions & 0 deletions next/src/stores/agentStore.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ interface AgentSlice {
agent: AutonomousAgent | null;
lifecycle: AgentLifecycle;
setLifecycle: (AgentLifecycle) => void;
summarized: boolean;
setSummarized: (boolean) => void;
isAgentThinking: boolean;
setIsAgentThinking: (isThinking: boolean) => void;
setAgent: (newAgent: AutonomousAgent | null) => void;
Expand All @@ -18,6 +20,7 @@ interface AgentSlice {
const initialAgentState = {
agent: null,
lifecycle: "offline" as const,
summarized: false,
isAgentThinking: false,
isAgentPaused: undefined,
};
Expand All @@ -38,6 +41,11 @@ const createAgentSlice: StateCreator<AgentSlice> = (set, get) => {
lifecycle: lifecycle,
}));
},
setSummarized: (summarized: boolean) => {
set(() => ({
summarized: summarized,
}));
},
setIsAgentThinking: (isThinking: boolean) => {
set(() => ({
isAgentThinking: isThinking,
Expand Down Expand Up @@ -83,3 +91,5 @@ export const useAgentStore = createSelectors(
)
)
);

export const resetAllAgentSlices = () => resetters.forEach((resetter) => resetter());
3 changes: 2 additions & 1 deletion next/src/utils/interfaces.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ export const toApiModelSettings = (modelSettings: ModelSettings, session?: Sessi
};

export interface RequestBody {
run_id?: string;
model_settings: ApiModelSettings;
goal: string;
task?: string;
Expand All @@ -32,5 +33,5 @@ export interface RequestBody {
completed_tasks?: string[];
analysis?: Analysis;
tool_names?: string[];
run_id?: string;
message?: string; // Used for the chat endpoint
}
1 change: 1 addition & 0 deletions platform/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ RUN apt-get update && apt-get install -y \
default-libmysqlclient-dev \
pkg-config \
gcc \
ca-certificates \
pkg-config \
&& rm -rf /var/lib/apt/lists/*

Expand Down
6 changes: 6 additions & 0 deletions platform/reworkd_platform/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
"execute",
"create",
"summarize",
"chat",
]

LLM_MODEL_MAX_TOKENS: Dict[LLM_Model, int] = {
Expand Down Expand Up @@ -72,6 +73,11 @@ class AgentSummarize(AgentRun):
results: List[str] = Field(default=[])


class AgentChat(AgentRun):
message: str
results: List[str] = Field(default=[])


class NewTasksResponse(BaseModel):
run_id: str
new_tasks: List[str] = Field(alias="newTasks")
Expand Down
Loading

0 comments on commit 41806d4

Please sign in to comment.