Skip to content

Commit

Permalink
[major]: Simplify yielding UI elements
Browse files Browse the repository at this point in the history
  • Loading branch information
bracesproul committed Jul 31, 2024
1 parent aab1ba6 commit d631992
Show file tree
Hide file tree
Showing 8 changed files with 209 additions and 166 deletions.
2 changes: 1 addition & 1 deletion ai/graph.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ const invokeTools = async (
throw new Error("No tool found in tool map.");
}
const toolResult = await selectedTool.invoke(
state.toolCall.parameters,
state.toolCall.parameters as any,
config,
);
return {
Expand Down
15 changes: 7 additions & 8 deletions ai/tools/firecrawl.tsx
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { WebLoading, Web } from "@/components/prebuilt/web";
import { createRunnableUI } from "@/utils/server";
import { DynamicStructuredTool } from "@langchain/core/tools";
import { tool } from "@langchain/core/tools";
import { FireCrawlLoader } from "@langchain/community/document_loaders/web/firecrawl";
import { z } from "zod";

Expand Down Expand Up @@ -31,14 +31,13 @@ export async function webData(input: z.infer<typeof webSchema>) {
};
}

export const websiteDataTool = new DynamicStructuredTool({
name: "get_web_data",
description: "A tool to fetch the current website data, given a url.",
schema: webSchema,
func: async (input, config) => {
const stream = await createRunnableUI(config, <WebLoading />);
export const websiteDataTool = tool(async (input, config) => {
const stream = await createRunnableUI(config, <WebLoading />);
const data = await webData(input);
stream.done(<Web {...data} />);
return JSON.stringify(data, null);
},
}, {
name: "get_web_data",
description: "A tool to fetch the current website data, given a url.",
schema: webSchema,
});
42 changes: 29 additions & 13 deletions ai/tools/github_repo.tsx
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import { z } from "zod";
import { Octokit } from "octokit";
import { DynamicStructuredTool } from "@langchain/core/tools";
import { createRunnableUI } from "@/utils/server";
import { tool } from "@langchain/core/tools";
import { createRunnableUI, CUSTOM_UI_YIELD_NAME } from "@/utils/server";
import { Github, GithubLoading } from "@/components/prebuilt/github";
import { dispatchCustomEvent } from "@langchain/core/callbacks/dispatch/web";

const githubRepoToolSchema = z.object({
owner: z.string().describe("The name of the repository owner."),
Expand Down Expand Up @@ -36,20 +37,35 @@ async function githubRepoTool(input: z.infer<typeof githubRepoToolSchema>) {
}
}

export const githubTool = new DynamicStructuredTool({
name: "github_repo",
description:
"A tool to fetch details of a Github repository. Given owner and repo names, this tool will return the repo description, stars, and primary language.",
schema: githubRepoToolSchema,
func: async (input, config) => {
const stream = await createRunnableUI(config, <GithubLoading />);
export const githubTool = tool(async (input, config) => {
// const stream = await createRunnableUI(config, <GithubLoading />);
await dispatchCustomEvent(CUSTOM_UI_YIELD_NAME, {
output: {
value: <GithubLoading />,
type: "append",
}
}, config);
const result = await githubRepoTool(input);
if (typeof result === "string") {
// Failed to parse, return error message
stream.done(<p>{result}</p>);
await dispatchCustomEvent(CUSTOM_UI_YIELD_NAME, {
output: {
value: <p>{result}</p>,
type: "update",
}
}, config);
return result;
}
stream.done(<Github {...result} />);
await dispatchCustomEvent(CUSTOM_UI_YIELD_NAME, {
output: {
value: <Github {...result} />,
type: "update",
}
}, config);
return JSON.stringify(result, null);
},
});
}, {
name: "github_repo",
description:
"A tool to fetch details of a Github repository. Given owner and repo names, this tool will return the repo description, stars, and primary language.",
schema: githubRepoToolSchema,
})
15 changes: 7 additions & 8 deletions ai/tools/invoice.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import { z } from "zod";
import { v4 as uuidv4 } from "uuid";
import { InvoiceLoading, Invoice } from "@/components/prebuilt/invoice";
import { createRunnableUI } from "@/utils/server";
import { DynamicStructuredTool } from "@langchain/core/tools";
import { DynamicStructuredTool, tool } from "@langchain/core/tools";

const LineItemSchema = z.object({
id: z
Expand Down Expand Up @@ -51,14 +51,13 @@ export const InvoiceSchema = z.object({
),
});

export const invoiceTool = new DynamicStructuredTool({
export const invoiceTool = tool(async (input, config) => {
const stream = await createRunnableUI(config, <InvoiceLoading />);
stream.done(<Invoice {...input} />);
return JSON.stringify(input, null);
}, {
name: "get_order_invoice",
description:
"A tool to fetch the invoice from an order. This should only be called if a user uploads an image/receipt of an order.",
schema: InvoiceSchema,
func: async (input, config) => {
const stream = await createRunnableUI(config, <InvoiceLoading />);
stream.done(<Invoice {...input} />);
return JSON.stringify(input, null);
},
});
})
19 changes: 10 additions & 9 deletions ai/tools/weather.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import {
CurrentWeather,
} from "@/components/prebuilt/weather";
import { createRunnableUI } from "@/utils/server";
import { DynamicStructuredTool } from "@langchain/core/tools";
import { DynamicStructuredTool, tool } from "@langchain/core/tools";
import { z } from "zod";

export const weatherSchema = z.object({
Expand All @@ -18,7 +18,9 @@ export const weatherSchema = z.object({
.describe("The two letter country abbreviation to get weather for"),
});

export async function weatherData(input: z.infer<typeof weatherSchema>) {
export type WeatherToolSchema = z.infer<typeof weatherSchema>;

export async function weatherData(input: WeatherToolSchema) {
const geoCodeApiKey = process.env.GEOCODE_API_KEY;
if (!geoCodeApiKey) {
throw new Error("Missing GEOCODE_API_KEY secret.");
Expand Down Expand Up @@ -59,15 +61,14 @@ export async function weatherData(input: z.infer<typeof weatherSchema>) {
};
}

export const weatherTool = new DynamicStructuredTool({
export const weatherTool = tool(async (input, config) => {
const stream = await createRunnableUI(config, <CurrentWeatherLoading />);
const data = await weatherData(input);
stream.done(<CurrentWeather {...data} />);
return JSON.stringify(data, null);
}, {
name: "get_weather",
description:
"A tool to fetch the current weather, given a city and state. If the city/state is not provided, ask the user for both the city and state.",
schema: weatherSchema,
func: async (input, config) => {
const stream = await createRunnableUI(config, <CurrentWeatherLoading />);
const data = await weatherData(input);
stream.done(<CurrentWeather {...data} />);
return JSON.stringify(data, null);
},
});
14 changes: 7 additions & 7 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@
"format": "yarn prettier --write ./app ./ai ./components ./utils ./lib"
},
"dependencies": {
"@langchain/community": "^0.2.10",
"@langchain/core": "^0.2.6",
"@langchain/langgraph": "^0.0.25",
"@langchain/openai": "^0.0.34",
"@langchain/community": "^0.2.22",
"@langchain/core": "^0.2.18",
"@langchain/langgraph": "^0.0.31",
"@langchain/openai": "^0.2.5",
"@mendable/firecrawl-js": "^0.0.26",
"@radix-ui/react-avatar": "^1.0.4",
"@radix-ui/react-dropdown-menu": "^2.0.6",
Expand All @@ -27,12 +27,12 @@
"@radix-ui/react-switch": "^1.0.3",
"@radix-ui/react-tabs": "^1.0.4",
"@radix-ui/react-tooltip": "^1.0.7",
"ai": "^3.1.16",
"ai": "^3.2.43",
"class-variance-authority": "^0.7.0",
"clsx": "^2.1.1",
"date-fns": "^3.6.0",
"jotai": "^2.8.2",
"langchain": "^0.2.3",
"langchain": "^0.2.12",
"lucide-react": "^0.379.0",
"next": "14.2.3",
"octokit": "^4.0.2",
Expand All @@ -58,6 +58,6 @@
"typescript": "^5"
},
"resolutions": {
"@langchain/core": "0.2.6"
"@langchain/core": "0.2.18"
}
}
34 changes: 18 additions & 16 deletions utils/server.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,16 @@ import "server-only";

import { ReactNode, isValidElement } from "react";
import { AIProvider } from "./client";
import {
CallbackManagerForToolRun,
CallbackManagerForRetrieverRun,
CallbackManagerForChainRun,
} from "@langchain/core/callbacks/manager";
import { createStreamableUI, createStreamableValue } from "ai/rsc";
import { Runnable, RunnableLambda } from "@langchain/core/runnables";
import { Runnable, RunnableConfig, RunnableLambda } from "@langchain/core/runnables";
import { CompiledStateGraph } from "@langchain/langgraph";
import { StreamEvent } from "@langchain/core/tracers/log_stream";
import { AIMessage } from "@/ai/message";

export const dynamic = 'force-dynamic';

const STREAM_UI_RUN_NAME = "stream_runnable_ui";
export const CUSTOM_UI_YIELD_NAME = "__yield_ui__";

/**
* Executes `streamEvents` method on a runnable
Expand Down Expand Up @@ -42,9 +40,17 @@ export function streamRunnableUI<RunInput, RunOutput>(
for await (const streamEvent of (
runnable as Runnable<RunInput, RunOutput>
).streamEvents(inputs, {
version: "v1",
version: "v2",
})) {
if (
if (streamEvent.name === CUSTOM_UI_YIELD_NAME) {
if (isValidElement(streamEvent.data.output.value)) {
if (streamEvent.data.output.type === "append") {
ui.append(streamEvent.data.output.value);
} else if (streamEvent.data.output.type === "update") {
ui.update(streamEvent.data.output.value);
}
}
} else if (
streamEvent.name === STREAM_UI_RUN_NAME &&
streamEvent.event === "on_chain_end"
) {
Expand Down Expand Up @@ -88,19 +94,15 @@ export function streamRunnableUI<RunInput, RunOutput>(
* Yields an UI element within a runnable,
* which can be streamed to the client via `streamRunnableUI`
*
* @param callbackManager callback
* @param config RunnableConfig
* @param initialValue Initial React node to be sent to the client
* @returns Vercel AI RSC compatible streamable UI
*/
export const createRunnableUI = async (
callbackManager:
| CallbackManagerForToolRun
| CallbackManagerForRetrieverRun
| CallbackManagerForChainRun
| undefined,
config: RunnableConfig | undefined,
initialValue?: React.ReactNode,
): Promise<ReturnType<typeof createStreamableUI>> => {
if (!callbackManager) {
if (!config) {
throw new Error("Callback manager is not defined");
}

Expand All @@ -109,7 +111,7 @@ export const createRunnableUI = async (
return ui;
}).withConfig({ runName: STREAM_UI_RUN_NAME });

return lambda.invoke(initialValue, { callbacks: callbackManager.getChild() });
return lambda.invoke(initialValue, config);
};

/**
Expand Down
Loading

0 comments on commit d631992

Please sign in to comment.