Skip to content

Commit

Permalink
improve ui streaming
Browse files Browse the repository at this point in the history
  • Loading branch information
bracesproul committed Jun 7, 2024
1 parent 4256a7b commit fc7c860
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 37 deletions.
2 changes: 1 addition & 1 deletion ai/tools/github_repo.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ export const githubTool = new DynamicStructuredTool({
"A tool to fetch details of a Github repository. Given owner and repo names, this tool will return the repo description, stars, and primary language.",
schema: githubRepoToolSchema,
func: async (input, config) => {
const stream = createRunnableUI(config, <GithubLoading />);
const stream = await createRunnableUI(config, <GithubLoading />);
const result = await githubRepoTool(input);
if (typeof result === "string") {
// Failed to parse, return error message
Expand Down
2 changes: 1 addition & 1 deletion ai/tools/invoice.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ export const invoiceTool = new DynamicStructuredTool({
"A tool to fetch the invoice from an order. This should only be called if a user uploads an image/receipt of an order.",
schema: InvoiceSchema,
func: async (input, config) => {
const stream = createRunnableUI(config, <InvoiceLoading />);
const stream = await createRunnableUI(config, <InvoiceLoading />);
stream.done(<Invoice {...input} />);
return JSON.stringify(input, null);
},
Expand Down
2 changes: 1 addition & 1 deletion ai/tools/weather.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ export const weatherTool = new DynamicStructuredTool({
"A tool to fetch the current weather, given a city and state. If the city/state is not provided, ask the user for both the city and state.",
schema: weatherSchema,
func: async (input, config) => {
const stream = createRunnableUI(config, <CurrentWeatherLoading />);
const stream = await createRunnableUI(config, <CurrentWeatherLoading />);
const data = await weatherData(input);
stream.done(<CurrentWeather {...data} />);
return JSON.stringify(data, null);
Expand Down
7 changes: 5 additions & 2 deletions components/prebuilt/chat.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,13 @@ export default function Chat() {
setHistory((prev) => [
...prev,
["user", input],
["assistant", `Tool result: ${JSON.stringify(lastEvent["invokeTools"]["toolResult"], null)}`],
[
"assistant",
`Tool result: ${JSON.stringify(lastEvent["invokeTools"]["toolResult"], null)}`,
],
]);
} else {
console.log("ELSE!", lastEvent)
console.log("ELSE!", lastEvent);
}
}
})();
Expand Down
57 changes: 25 additions & 32 deletions utils/server.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@ import "server-only";

import { ReactNode, isValidElement } from "react";
import { createStreamableUI, createStreamableValue } from "ai/rsc";
import { Runnable } from "@langchain/core/runnables";
import {
Runnable,
RunnableConfig,
RunnableLambda,
} from "@langchain/core/runnables";
import {
CallbackManagerForToolRun,
CallbackManagerForRetrieverRun,
Expand All @@ -18,6 +22,8 @@ import { AIProvider } from "./client";
import { AIMessage } from "../ai/message";
import { CompiledStateGraph } from "@langchain/langgraph";

const STREAM_UI_RUN_NAME = "stream_ui_lambda";

/**
* Executes `streamEvents` method on a runnable
* and converts the generator to a RSC friendly stream
Expand Down Expand Up @@ -47,10 +53,17 @@ export function streamRunnableUI<RunInput, RunOutput>(
).streamEvents(inputs, {
version: "v1",
})) {
if (
streamEvent.name === STREAM_UI_RUN_NAME &&
streamEvent.event === "on_chain_end"
) {
if (isValidElement(streamEvent.data.output.value)) {
ui.append(streamEvent.data.output.value);
}
}
const [kind, type] = streamEvent.event.split("_").slice(1);
if (type === "stream" && kind !== "chain") {
const chunk = streamEvent.data.chunk;

if (isValidElement(chunk)) {
ui.append(chunk);
} else if ("text" in chunk && typeof chunk.text === "string") {
Expand Down Expand Up @@ -88,44 +101,24 @@ export function streamRunnableUI<RunInput, RunOutput>(
* @param initialValue Initial React node to be sent to the client
* @returns Vercel AI RSC compatible streamable UI
*/
export const createRunnableUI = (
export const createRunnableUI = async (
config:
| CallbackManagerForToolRun
| CallbackManagerForRetrieverRun
| CallbackManagerForChainRun
| CallbackManagerForLLMRun
| undefined,
initialValue?: React.ReactNode,
): ReturnType<typeof createStreamableUI> => {
): Promise<ReturnType<typeof createStreamableUI>> => {
if (!config) throw new Error("No config provided");

const logStreamTracer = config.handlers.find(
(i): i is LogStreamCallbackHandler => i.name === "log_stream_tracer",
);

const ui = createStreamableUI(initialValue);

if (!logStreamTracer) throw new Error("No log stream tracer found");
// @ts-expect-error Private field
const runName = logStreamTracer.keyMapByRunId[config.runId];
if (!runName) {
console.log("No name found for", config.runId);
throw new Error("No run name found");
}

logStreamTracer.writer.write(
new RunLogPatch({
ops: [
{
op: "add",
path: `/logs/${runName}/streamed_output/-`,
value: ui.value,
},
],
}),
);

return ui;
const lambda = RunnableLambda.from(
(init: React.ReactNode, config?: RunnableConfig) => {
const ui = createStreamableUI(init);
return ui;
},
).withConfig({ runName: STREAM_UI_RUN_NAME });

return lambda.invoke(initialValue, { callbacks: config.getChild() });
};

/**
Expand Down

0 comments on commit fc7c860

Please sign in to comment.