Skip to content

Commit

Permalink
Add recursion limit for runnable lambda, various fixes for config, types
Browse files Browse the repository at this point in the history
  • Loading branch information
nfcampos committed Dec 26, 2023
1 parent ca93246 commit 23202d6
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 50 deletions.
7 changes: 0 additions & 7 deletions langchain-core/src/callbacks/manager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -61,13 +61,6 @@ export interface BaseCallbackConfig {
* Tags are passed to all callbacks, metadata is passed to handle*Start callbacks.
*/
callbacks?: Callbacks;

/**
* Runtime values for attributes previously made configurable on this Runnable,
* or sub-Runnables.
*/
// eslint-disable-next-line @typescript-eslint/no-explicit-any
configurable?: Record<string, any>;
}

export function parseCallbackConfigArg(
Expand Down
66 changes: 32 additions & 34 deletions langchain-core/src/runnables/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ import pRetry from "p-retry";
import {
CallbackManager,
CallbackManagerForChainRun,
BaseCallbackConfig,
} from "../callbacks/manager.js";
import {
LogStreamCallbackHandler,
Expand Down Expand Up @@ -59,12 +58,6 @@ export interface RunnableInterface<
batchOptions?: RunnableBatchOptions
): Promise<(RunOutput | Error)[]>;

batch(
inputs: RunInput[],
options?: Partial<CallOptions> | Partial<CallOptions>[],
batchOptions?: RunnableBatchOptions
): Promise<(RunOutput | Error)[]>;

stream(
input: RunInput,
options?: Partial<CallOptions>
Expand Down Expand Up @@ -433,30 +426,24 @@ export abstract class Runnable<
let finalOutputSupported = true;

const callbackManager_ = await getCallbackMangerForConfig(options);
let runManager: CallbackManagerForChainRun | undefined;
const serializedRepresentation = this.toJSON();
async function* wrapInputForTracing() {
for await (const chunk of inputGenerator) {
if (!runManager) {
// Start the run manager AFTER the iterator starts to preserve
// tracing order
runManager = await callbackManager_?.handleChainStart(
serializedRepresentation,
const runManager = await callbackManager_?.handleChainStart(
this.toJSON(),
{ input: "" },
undefined,
options?.runType,
undefined,
undefined,
options?.runName
);
}
async function* wrapInputForTracing() {
for await (const chunk of inputGenerator) {
if (finalInputSupported) {
if (finalInput === undefined) {
finalInput = chunk;
} else {
try {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
finalInput = (finalInput as any).concat(chunk);
finalInput = concat(finalInput, chunk as any);
} catch {
finalInput = undefined;
finalInputSupported = false;
Expand All @@ -482,7 +469,7 @@ export abstract class Runnable<
} else {
try {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
finalOutput = (finalOutput as any).concat(chunk);
finalOutput = concat(finalOutput, chunk as any);
} catch {
finalOutput = undefined;
finalOutputSupported = false;
Expand All @@ -507,7 +494,8 @@ export abstract class Runnable<

_patchConfig(
config: Partial<CallOptions> = {},
callbackManager: CallbackManager | undefined = undefined
callbackManager: CallbackManager | undefined = undefined,
recursionLimit: number | undefined = undefined
): Partial<CallOptions> {
const newConfig = { ...config };
if (callbackManager !== undefined) {
Expand All @@ -518,6 +506,9 @@ export abstract class Runnable<
delete newConfig.runName;
return { ...newConfig, callbacks: callbackManager };
}
if (recursionLimit !== undefined) {
newConfig.recursionLimit = recursionLimit;
}
return newConfig;
}

Expand Down Expand Up @@ -556,7 +547,7 @@ export abstract class Runnable<
// Make a best effort to gather, for any type that supports concat.
// This method should throw an error if gathering fails.
// eslint-disable-next-line @typescript-eslint/no-explicit-any
finalChunk = (finalChunk as any).concat(chunk);
finalChunk = concat(finalChunk, chunk as any);
}
}
yield* this._streamIterator(finalChunk, options);
Expand Down Expand Up @@ -670,7 +661,7 @@ export abstract class Runnable<
export type RunnableBindingArgs<
RunInput,
RunOutput,
CallOptions extends RunnableConfig
CallOptions extends RunnableConfig = RunnableConfig
> = {
bound: Runnable<RunInput, RunOutput, CallOptions>;
kwargs?: Partial<CallOptions>;
Expand All @@ -684,7 +675,7 @@ export type RunnableBindingArgs<
export class RunnableBinding<
RunInput,
RunOutput,
CallOptions extends RunnableConfig
CallOptions extends RunnableConfig = RunnableConfig
> extends Runnable<RunInput, RunOutput, CallOptions> {
static lc_name() {
return "RunnableBinding";
Expand Down Expand Up @@ -892,7 +883,7 @@ export class RunnableBinding<
export class RunnableEach<
RunInputItem,
RunOutputItem,
CallOptions extends BaseCallbackConfig
CallOptions extends RunnableConfig
> extends Runnable<RunInputItem[], RunOutputItem[], CallOptions> {
static lc_name() {
return "RunnableEach";
Expand Down Expand Up @@ -1360,7 +1351,7 @@ export class RunnableSequence<
} else {
try {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
finalOutput = (finalOutput as any).concat(chunk);
finalOutput = concat(finalOutput, chunk as any);
} catch (e) {
finalOutput = undefined;
concatSupported = false;
Expand Down Expand Up @@ -1473,7 +1464,7 @@ export class RunnableMap<

async invoke(
input: RunInput,
options?: Partial<BaseCallbackConfig>
options?: Partial<RunnableConfig>
): Promise<RunOutput> {
const callbackManager_ = await getCallbackMangerForConfig(options);
const runManager = await callbackManager_?.handleChainStart(
Expand Down Expand Up @@ -1537,22 +1528,29 @@ export class RunnableLambda<RunInput, RunOutput> extends Runnable<

async _invoke(
input: RunInput,
config?: Partial<BaseCallbackConfig>,
config?: Partial<RunnableConfig>,
runManager?: CallbackManagerForChainRun
) {
let output = await this.func(input, { config });
if (output && Runnable.isRunnable(output)) {
if (config?.recursionLimit === 0) {
throw new Error("Recursion limit reached.");
}
output = await output.invoke(
input,
this._patchConfig(config, runManager?.getChild())
this._patchConfig(
config,
runManager?.getChild(),
(config?.recursionLimit ?? 25) - 1
)
);
}
return output;
}

async invoke(
input: RunInput,
options?: Partial<BaseCallbackConfig>
options?: Partial<RunnableConfig>
): Promise<RunOutput> {
return this._callWithConfig(this._invoke, input, options);
}
Expand Down Expand Up @@ -1597,7 +1595,7 @@ export class RunnableWithFallbacks<RunInput, RunOutput> extends Runnable<

async invoke(
input: RunInput,
options?: Partial<BaseCallbackConfig>
options?: Partial<RunnableConfig>
): Promise<RunOutput> {
const callbackManager_ = await CallbackManager.configure(
options?.callbacks,
Expand Down Expand Up @@ -1639,25 +1637,25 @@ export class RunnableWithFallbacks<RunInput, RunOutput> extends Runnable<

async batch(
inputs: RunInput[],
options?: Partial<BaseCallbackConfig> | Partial<BaseCallbackConfig>[],
options?: Partial<RunnableConfig> | Partial<RunnableConfig>[],
batchOptions?: RunnableBatchOptions & { returnExceptions?: false }
): Promise<RunOutput[]>;

async batch(
inputs: RunInput[],
options?: Partial<BaseCallbackConfig> | Partial<BaseCallbackConfig>[],
options?: Partial<RunnableConfig> | Partial<RunnableConfig>[],
batchOptions?: RunnableBatchOptions & { returnExceptions: true }
): Promise<(RunOutput | Error)[]>;

async batch(
inputs: RunInput[],
options?: Partial<BaseCallbackConfig> | Partial<BaseCallbackConfig>[],
options?: Partial<RunnableConfig> | Partial<RunnableConfig>[],
batchOptions?: RunnableBatchOptions
): Promise<(RunOutput | Error)[]>;

async batch(
inputs: RunInput[],
options?: Partial<BaseCallbackConfig> | Partial<BaseCallbackConfig>[],
options?: Partial<RunnableConfig> | Partial<RunnableConfig>[],
batchOptions?: RunnableBatchOptions
): Promise<(RunOutput | Error)[]> {
if (batchOptions?.returnExceptions) {
Expand Down
16 changes: 15 additions & 1 deletion langchain-core/src/runnables/config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,19 @@ import {
CallbackManager,
} from "../callbacks/manager.js";

export type RunnableConfig = BaseCallbackConfig;
export interface RunnableConfig extends BaseCallbackConfig {
/**
* Runtime values for attributes previously made configurable on this Runnable,
* or sub-Runnables.
*/
// eslint-disable-next-line @typescript-eslint/no-explicit-any
configurable?: Record<string, any>;

/**
* Maximum number of times a call can recurse. If not provided, defaults to 25.
*/
recursionLimit?: number;
}

export async function getCallbackMangerForConfig(config?: RunnableConfig) {
return CallbackManager.configure(
Expand All @@ -28,6 +40,8 @@ export function mergeConfigs<CallOptions extends RunnableConfig>(
copy[key] = { ...copy[key], ...options[key] };
} else if (key === "tags") {
copy[key] = (copy[key] ?? []).concat(options[key] ?? []);
} else if (key === "configurable") {
copy[key] = { ...copy[key], ...options[key] };
} else if (key === "callbacks") {
const baseCallbacks = copy.callbacks;
const providedCallbacks = options.callbacks ?? config.callbacks;
Expand Down
12 changes: 4 additions & 8 deletions langchain-core/src/runnables/history.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import { BaseCallbackConfig } from "../callbacks/manager.js";
import {
BaseChatMessageHistory,
BaseListChatMessageHistory,
Expand Down Expand Up @@ -28,10 +27,7 @@ type GetSessionHistoryCallable = (
| BaseListChatMessageHistory;

export interface RunnableWithMessageHistoryInputs<RunInput, RunOutput>
extends Omit<
RunnableBindingArgs<RunInput, RunOutput, BaseCallbackConfig>,
"bound" | "config"
> {
extends Omit<RunnableBindingArgs<RunInput, RunOutput>, "bound" | "config"> {
runnable: Runnable<RunInput, RunOutput>;
getMessageHistory: GetSessionHistoryCallable;
inputMessagesKey?: string;
Expand All @@ -43,7 +39,7 @@ export interface RunnableWithMessageHistoryInputs<RunInput, RunOutput>
export class RunnableWithMessageHistory<
RunInput,
RunOutput
> extends RunnableBinding<RunInput, RunOutput, BaseCallbackConfig> {
> extends RunnableBinding<RunInput, RunOutput> {
runnable: Runnable<RunInput, RunOutput>;

inputMessagesKey?: string;
Expand Down Expand Up @@ -151,7 +147,7 @@ export class RunnableWithMessageHistory<
return returnType;
}

async _exitHistory(run: Run, config: BaseCallbackConfig): Promise<void> {
async _exitHistory(run: Run, config: RunnableConfig): Promise<void> {
const history = config.configurable?.messageHistory;

// Get input messages
Expand All @@ -176,7 +172,7 @@ export class RunnableWithMessageHistory<
}
}

async _mergeConfig(...configs: Array<BaseCallbackConfig | undefined>) {
async _mergeConfig(...configs: Array<RunnableConfig | undefined>) {
const config = await super._mergeConfig(...configs);
// Extract sessionId
if (!config.configurable || !config.configurable.sessionId) {
Expand Down

0 comments on commit 23202d6

Please sign in to comment.