Skip to content

Commit

Permalink
Implement stream and transform in RunnableMap
Browse files — browse the repository at this point in the history
  • Loading branch information
nfcampos committed Dec 26, 2023
1 parent 23202d6 commit 2cdb102
Show file tree
Hide file tree
Showing 4 changed files with 176 additions and 9 deletions.
76 changes: 68 additions & 8 deletions langchain-core/src/runnables/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@ import {
import { Serializable } from "../load/serializable.js";
import {
IterableReadableStream,
concat,
type IterableReadableStreamInterface,
atee,
} from "../utils/stream.js";
import {
RunnableConfig,
Expand Down Expand Up @@ -428,13 +430,13 @@ export abstract class Runnable<
const callbackManager_ = await getCallbackMangerForConfig(options);
const runManager = await callbackManager_?.handleChainStart(
this.toJSON(),
{ input: "" },
undefined,
options?.runType,
undefined,
undefined,
options?.runName
);
{ input: "" },
undefined,
options?.runType,
undefined,
undefined,
options?.runName
);
async function* wrapInputForTracing() {
for await (const chunk of inputGenerator) {
if (finalInputSupported) {
Expand Down Expand Up @@ -1485,7 +1487,7 @@ export class RunnableMap<
Object.entries(this.steps).map(async ([key, runnable]) => {
output[key] = await runnable.invoke(
input,
this._patchConfig(options, runManager?.getChild(key))
this._patchConfig(options, runManager?.getChild(`map:key:${key}`))
);
})
);
Expand All @@ -1496,6 +1498,64 @@ export class RunnableMap<
await runManager?.handleChainEnd(output);
return output as RunOutput;
}

async *_transform(
generator: AsyncGenerator<RunInput>,
runManager?: CallbackManagerForChainRun,
options?: Partial<RunnableConfig>
): AsyncGenerator<RunOutput> {
// shallow copy steps to ignore changes while iterating
const steps = { ...this.steps };
// each step gets a copy of the input iterator
const inputCopies = atee(generator, Object.keys(steps).length);
// start the first iteration of each output iterator
const tasks = new Map(
Object.entries(steps).map(([key, runnable], i) => {
const gen = runnable.transform(
inputCopies[i],
this._patchConfig(options, runManager?.getChild(`map:key:${key}`))
);
return [key, gen.next().then((result) => ({ key, gen, result }))];
})
);
// yield chunks as they become available,
// starting new iterations as needed,
// until all iterators are done
while (tasks.size) {
const { key, result, gen } = await Promise.race(tasks.values());
tasks.delete(key);
if (!result.done) {
yield { [key]: result.value } as unknown as RunOutput;
tasks.set(
key,
gen.next().then((result) => ({ key, gen, result }))
);
}
}
}

transform(
  generator: AsyncGenerator<RunInput, any, unknown>,
  options?: Partial<RunnableConfig>
): AsyncGenerator<RunOutput, any, unknown> {
  // Delegate to the shared helper that wires callbacks/config handling
  // around the underlying _transform implementation. Bind explicitly so
  // _transform keeps `this` when invoked by the helper.
  const boundTransform = this._transform.bind(this);
  return this._transformStreamWithConfig(generator, boundTransform, options);
}

async stream(
  input: RunInput,
  options?: Partial<RunnableConfig> | undefined
): Promise<IterableReadableStream<RunOutput>> {
  // Adapt the single input value into a one-item async generator so the
  // streaming transform() path can be reused for stream().
  const singleInput = async function* () {
    yield input;
  };
  return IterableReadableStream.fromAsyncGenerator(
    this.transform(singleInput(), options)
  );
}
}

/**
Expand Down
41 changes: 41 additions & 0 deletions langchain-core/src/runnables/tests/runnable_map.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@ import {
SystemMessagePromptTemplate,
HumanMessagePromptTemplate,
} from "../../prompts/chat.js";
import { concat } from "../../utils/stream.js";
import {
FakeLLM,
FakeChatModel,
FakeRetriever,
FakeStreamingLLM,
} from "../../utils/testing/index.js";
import { RunnableSequence, RunnableMap } from "../base.js";
import { RunnablePassthrough } from "../passthrough.js";
Expand Down Expand Up @@ -103,3 +105,42 @@ test("Should not allow improper outputs from a map into the next item in a seque
const runnable = map.pipe(new FakeLLM({}));
console.log(runnable);
});

// Verifies that RunnableMap streams chunks from each of its steps as they
// are produced (rather than buffering until every step completes), and that
// concatenating all streamed chunks reproduces the invoke() result.
test("Should stream chunks from each step as they are produced", async () => {
const prompt = ChatPromptTemplate.fromMessages([
["system", "You are a nice assistant."],
"{question}",
]);

const chat = new FakeChatModel({});

// sleep: 0 keeps the per-character streaming fast in CI while still
// yielding to the event loop between chunks.
const llm = new FakeStreamingLLM({ sleep: 0 });

const chain = RunnableSequence.from([
prompt,
RunnableMap.from({
passthrough: new RunnablePassthrough(),
chat,
llm,
}),
]);

const stream = await chain.stream({ question: "What is your name?" });

const chunks = [];

for await (const chunk of stream) {
chunks.push(chunk);
}

// More chunks than steps proves at least one step emitted incrementally.
expect(chunks.length).toBeGreaterThan(3);
// Folding all chunks with concat must equal the non-streaming result.
expect(chunks.reduce(concat)).toEqual(
await chain.invoke({ question: "What is your name?" })
);

// Selecting a single key out of the map output still works downstream.
const chainWithSelect = chain.pipe((output) => output.llm);

expect(await chainWithSelect.invoke({ question: "What is your name?" }))
.toEqual(`System: You are a nice assistant.
Human: What is your name?`);
});
59 changes: 59 additions & 0 deletions langchain-core/src/utils/stream.ts
Original file line number Diff line number Diff line change
Expand Up @@ -98,3 +98,62 @@ export class IterableReadableStream<T>
});
}
}

export function atee<T>(
  iter: AsyncGenerator<T>,
  length = 2
): AsyncGenerator<T>[] {
  // Tee one async generator into `length` independent consumers. Each
  // consumer has its own FIFO queue; whichever consumer runs dry pulls the
  // next result from the source and fans it out to every queue, so slow
  // consumers buffer rather than block fast ones.
  const queues = Array.from(
    { length },
    () => [] as Array<IteratorResult<T> | IteratorReturnResult<T>>
  );
  async function* branch(
    queue: Array<IteratorResult<T> | IteratorReturnResult<T>>
  ): AsyncGenerator<T> {
    for (;;) {
      if (queue.length > 0) {
        const head = queue[0];
        // Leave the terminal result in place so the queue stays "done".
        if (head.done) {
          return;
        }
        queue.shift();
        yield head.value as T;
      } else {
        const result = await iter.next();
        for (const q of queues) {
          q.push(result);
        }
      }
    }
  }
  return queues.map((queue) => branch(queue));
}

export function concat<
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  T extends Array<any> | string | number | Record<string, any> | any
>(first: T, second: T): T {
  // Combine two like-typed values: arrays append, strings and numbers add,
  // objects with a `.concat` method (e.g. message/generation chunks)
  // delegate to it, and plain objects merge recursively key-by-key.
  // Throws for anything else (booleans, mismatched primitives, null, ...).
  if (Array.isArray(first) && Array.isArray(second)) {
    return first.concat(second) as T;
  } else if (typeof first === "string" && typeof second === "string") {
    return (first + second) as T;
  } else if (typeof first === "number" && typeof second === "number") {
    return (first + second) as T;
  } else if (
    // Guard with an object check before using `in`: the `in` operator
    // throws a raw TypeError on primitives and null, which previously
    // masked the intended "Cannot concat" error below.
    typeof first === "object" &&
    first !== null &&
    // eslint-disable-next-line @typescript-eslint/no-explicit-any
    "concat" in (first as any) &&
    // eslint-disable-next-line @typescript-eslint/no-explicit-any
    typeof (first as any).concat === "function"
  ) {
    // eslint-disable-next-line @typescript-eslint/no-explicit-any
    return (first as any).concat(second) as T;
  } else if (
    typeof first === "object" &&
    first !== null &&
    typeof second === "object" &&
    // typeof null === "object", so exclude null explicitly: Object.entries
    // would otherwise throw a TypeError on a null `second`.
    second !== null
  ) {
    // eslint-disable-next-line @typescript-eslint/no-explicit-any
    const chunk = { ...first } as Record<string, any>;
    // eslint-disable-next-line @typescript-eslint/no-explicit-any
    for (const [key, value] of Object.entries(second as Record<string, any>)) {
      if (key in chunk) {
        // Key present on both sides: combine the values recursively.
        chunk[key] = concat(chunk[key], value);
      } else {
        chunk[key] = value;
      }
    }
    return chunk as T;
  } else {
    throw new Error(`Cannot concat ${typeof first} and ${typeof second}`);
  }
}
9 changes: 8 additions & 1 deletion langchain-core/src/utils/testing/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,13 @@ export class FakeLLM extends LLM {
}

export class FakeStreamingLLM extends LLM {
// Delay in milliseconds between streamed chunks; defaults to 50 to
// simulate token-by-token latency in tests.
sleep?: number = 50;

constructor(fields: { sleep?: number } & BaseLLMParams) {
super(fields);
// ?? (not ||) so an explicit sleep of 0 is honored rather than replaced
// by the 50 ms default.
this.sleep = fields.sleep ?? this.sleep;
}

// Type tag identifying this test double.
_llmType() {
return "fake";
}
Expand All @@ -109,7 +116,7 @@ export class FakeStreamingLLM extends LLM {

// Streams the prompt back one character per chunk, pausing `this.sleep`
// milliseconds before each chunk.
// NOTE: the span previously contained both the old hard-coded
// `setTimeout(resolve, 50)` line and its configurable replacement
// back-to-back, which would have slept twice per chunk; only the
// configurable delay is kept.
async *_streamResponseChunks(input: string) {
  for (const c of input) {
    await new Promise((resolve) => setTimeout(resolve, this.sleep));
    yield { text: c, generationInfo: {} } as GenerationChunk;
  }
}
Expand Down

0 comments on commit 2cdb102

Please sign in to comment.