Skip to content

Commit

Permalink
New Feat: Added Retriever Callbacks (langchain-ai#1990)
Browse files Browse the repository at this point in the history
* init retriever callbacks

* fixed a bunch of names

* fixed lc_namespaces

* fixed parameters for vectorstore function calls to add callbacks

* changed secrets

* fixed lc serializable values

* fixed lc_namespace in BaseRetriever

* fixed hyde lc_namespace

* fix namespaces

* added tests, added retriever callbacks to tracer.

* added clients to serialization; will probably handle this later.

* added auth bearer to lc_secrets in base remote retriever

* update import_type

* fixed linting

* Adds retriever callback handlers to LangChain tracer

* fixed naming, delete comment

* changed the time_weighted getSalientDocuments parameter type from callbacks to a run parameter. Changed the base retriever's _getRelevantDocuments to a regular function that throws a "not implemented" error

* Remove unnecessary overrides, change comment

* Tag vector store retriever callbacks with vector store type

* Remove stray .only

* Fix build

---------

Co-authored-by: jacoblee93 <[email protected]>
  • Loading branch information
ppramesi and jacoblee93 authored Jul 19, 2023
1 parent 4b3111b commit 460e110
Show file tree
Hide file tree
Showing 65 changed files with 977 additions and 99 deletions.
3 changes: 3 additions & 0 deletions langchain/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,9 @@ schema/output_parser.d.ts
schema/query_constructor.cjs
schema/query_constructor.js
schema/query_constructor.d.ts
schema/retriever.cjs
schema/retriever.js
schema/retriever.d.ts
sql_db.cjs
sql_db.js
sql_db.d.ts
Expand Down
8 changes: 8 additions & 0 deletions langchain/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,9 @@
"schema/query_constructor.cjs",
"schema/query_constructor.js",
"schema/query_constructor.d.ts",
"schema/retriever.cjs",
"schema/retriever.js",
"schema/retriever.d.ts",
"sql_db.cjs",
"sql_db.js",
"sql_db.d.ts",
Expand Down Expand Up @@ -1423,6 +1426,11 @@
"import": "./schema/query_constructor.js",
"require": "./schema/query_constructor.cjs"
},
"./schema/retriever": {
"types": "./schema/retriever.d.ts",
"import": "./schema/retriever.js",
"require": "./schema/retriever.cjs"
},
"./sql_db": {
"types": "./sql_db.d.ts",
"import": "./sql_db.js",
Expand Down
1 change: 1 addition & 0 deletions langchain/scripts/create-entrypoints.js
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ const entrypoints = {
schema: "schema/index",
"schema/output_parser": "schema/output_parser",
"schema/query_constructor": "schema/query_constructor",
"schema/retriever": "schema/retriever",
// sql_db
sql_db: "sql_db",
// callbacks
Expand Down
28 changes: 28 additions & 0 deletions langchain/src/callbacks/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import {
SerializedNotImplemented,
} from "../load/serializable.js";
import { SerializedFields } from "../load/map_keys.js";
import { Document } from "../document.js";

// eslint-disable-next-line @typescript-eslint/no-explicit-any
type Error = any;
Expand All @@ -20,6 +21,7 @@ export interface BaseCallbackHandlerInput {
ignoreLLM?: boolean;
ignoreChain?: boolean;
ignoreAgent?: boolean;
ignoreRetriever?: boolean;
}

export interface NewTokenIndices {
Expand Down Expand Up @@ -187,6 +189,29 @@ abstract class BaseCallbackHandlerMethodsClass {
parentRunId?: string,
tags?: string[]
): Promise<void> | void;

/**
 * Optional handler for the start of a retriever run.
 *
 * @param retriever Serialized representation of the retriever.
 * @param query The query string passed to the retriever.
 * @param runId Identifier of this run.
 * @param parentRunId Identifier of the parent run, if any.
 * @param tags Tags associated with the run.
 * @param metadata Arbitrary metadata associated with the run.
 */
handleRetrieverStart?(
retriever: Serialized,
query: string,
runId: string,
parentRunId?: string,
tags?: string[],
metadata?: Record<string, unknown>
): Promise<void> | void;

/**
 * Optional handler for the end of a retriever run.
 *
 * @param documents The documents produced by the run.
 * @param runId Identifier of this run.
 * @param parentRunId Identifier of the parent run, if any.
 * @param tags Tags associated with the run.
 */
handleRetrieverEnd?(
documents: Document[],
runId: string,
parentRunId?: string,
tags?: string[]
): Promise<void> | void;

/**
 * Optional handler for a retriever run that failed.
 *
 * @param err The error reported for the run (`Error` is aliased to `any` in this file).
 * @param runId Identifier of this run.
 * @param parentRunId Identifier of the parent run, if any.
 * @param tags Tags associated with the run.
 */
handleRetrieverError?(
err: Error,
runId: string,
parentRunId?: string,
tags?: string[]
): Promise<void> | void;
}

/**
Expand Down Expand Up @@ -232,6 +257,8 @@ export abstract class BaseCallbackHandler

ignoreAgent = false;

// When true, callback managers skip this handler's handleRetriever* methods.
// Defaults to false; can be overridden through the constructor input.
ignoreRetriever = false;

awaitHandlers =
typeof process !== "undefined"
? // eslint-disable-next-line no-process-env
Expand All @@ -245,6 +272,7 @@ export abstract class BaseCallbackHandler
this.ignoreLLM = input.ignoreLLM ?? this.ignoreLLM;
this.ignoreChain = input.ignoreChain ?? this.ignoreChain;
this.ignoreAgent = input.ignoreAgent ?? this.ignoreAgent;
this.ignoreRetriever = input.ignoreRetriever ?? this.ignoreRetriever;
}
}

Expand Down
37 changes: 37 additions & 0 deletions langchain/src/callbacks/handlers/console.ts
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,43 @@ export class ConsoleCallbackHandler extends BaseTracer {
);
}

/**
 * Logs a "[retriever/start]" line containing the run's breadcrumbs and its
 * JSON-stringified inputs.
 */
onRetrieverStart(run: Run) {
  const breadcrumbs = this.getBreadcrumbs(run);
  const label = wrap(color.green, "[retriever/start]");
  const renderedInputs = tryJsonStringify(run.inputs, "[inputs]");
  console.log(
    `${label} [${breadcrumbs}] Entering Retriever run with input: ${renderedInputs}`
  );
}

/**
 * Logs a "[retriever/end]" line containing the run's breadcrumbs, the
 * elapsed time and its JSON-stringified outputs.
 */
onRetrieverEnd(run: Run) {
  const breadcrumbs = this.getBreadcrumbs(run);
  const label = wrap(color.cyan, "[retriever/end]");
  const duration = elapsed(run);
  const renderedOutputs = tryJsonStringify(run.outputs, "[outputs]");
  console.log(
    `${label} [${breadcrumbs}] [${duration}] Exiting Retriever run with output: ${renderedOutputs}`
  );
}

/**
 * Logs a "[retriever/error]" line containing the run's breadcrumbs, the
 * elapsed time and its JSON-stringified error.
 */
onRetrieverError(run: Run) {
  const breadcrumbs = this.getBreadcrumbs(run);
  const label = wrap(color.red, "[retriever/error]");
  const duration = elapsed(run);
  const renderedError = tryJsonStringify(run.error, "[error]");
  console.log(
    `${label} [${breadcrumbs}] [${duration}] Retriever run errored with error: ${renderedError}`
  );
}

onAgentAction(run: Run) {
const agentRun = run as AgentRun;
const crumbs = this.getBreadcrumbs(run);
Expand Down
75 changes: 75 additions & 0 deletions langchain/src/callbacks/handlers/tracer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import {
BaseCallbackHandlerInput,
NewTokenIndices,
} from "../base.js";
import { Document } from "../../document.js";

// NOTE(review): BaseTracer.handleRetrieverStart records runs with
// run_type "retriever", which is not a member of this union — confirm
// whether "retriever" should be added here.
export type RunType = "llm" | "chain" | "tool";

Expand Down Expand Up @@ -353,6 +354,74 @@ export abstract class BaseTracer extends BaseCallbackHandler {
await this.onAgentEnd?.(run);
}

/**
 * Creates a run record of type "retriever" for the given query, registers
 * it via `_startTrace`, then invokes the optional `onRetrieverStart` hook.
 *
 * @param retriever Serialized retriever; the last element of its `id` is used as the run name.
 * @param query Query string, stored as `inputs.query`.
 * @param runId Identifier for the new run.
 * @param parentRunId Identifier of the parent run, if any.
 * @param tags Tags for the run; defaults to an empty array.
 * @param metadata Optional metadata, stored under `extra.metadata` when provided.
 */
async handleRetrieverStart(
  retriever: Serialized,
  query: string,
  runId: string,
  parentRunId?: string,
  tags?: string[],
  metadata?: KVMap
): Promise<void> {
  const executionOrder = this._getExecutionOrder(parentRunId);
  const startedAt = Date.now();
  const idPath = retriever.id;

  const retrieverRun: Run = {
    id: runId,
    name: idPath[idPath.length - 1],
    parent_run_id: parentRunId,
    start_time: startedAt,
    serialized: retriever,
    events: [{ name: "start", time: startedAt }],
    inputs: { query },
    execution_order: executionOrder,
    // A fresh run has no children yet, so both orders start out equal.
    child_execution_order: executionOrder,
    run_type: "retriever",
    child_runs: [],
    extra: metadata ? { metadata } : {},
    tags: tags || [],
  };

  this._startTrace(retrieverRun);
  await this.onRetrieverStart?.(retrieverRun);
}

/**
 * Marks the retriever run identified by `runId` as finished: records the
 * end time, stores the documents as outputs, appends an "end" event, then
 * invokes the optional `onRetrieverEnd` hook followed by `_endTrace`.
 *
 * @param documents Documents to store as `outputs.documents`.
 * @param runId Identifier of the run to end.
 * @throws Error when no run with this id is tracked or it is not a retriever run.
 */
async handleRetrieverEnd(
  documents: Document<Record<string, unknown>>[],
  runId: string
): Promise<void> {
  const retrieverRun = this.runMap.get(runId);
  if (!(retrieverRun && retrieverRun.run_type === "retriever")) {
    throw new Error("No retriever run to end");
  }

  const finishedAt = Date.now();
  retrieverRun.end_time = finishedAt;
  retrieverRun.outputs = { documents };
  retrieverRun.events.push({ name: "end", time: finishedAt });

  await this.onRetrieverEnd?.(retrieverRun);
  await this._endTrace(retrieverRun);
}

/**
 * Marks the retriever run identified by `runId` as failed: records the end
 * time, stores the error's message, appends an "error" event, then invokes
 * the optional `onRetrieverError` hook followed by `_endTrace`.
 *
 * @param error The error whose `message` is stored on the run.
 * @param runId Identifier of the run that failed.
 * @throws Error when no run with this id is tracked or it is not a retriever run.
 */
async handleRetrieverError(error: Error, runId: string): Promise<void> {
  const retrieverRun = this.runMap.get(runId);
  if (!(retrieverRun && retrieverRun.run_type === "retriever")) {
    throw new Error("No retriever run to end");
  }

  const failedAt = Date.now();
  retrieverRun.end_time = failedAt;
  retrieverRun.error = error.message;
  retrieverRun.events.push({ name: "error", time: failedAt });

  await this.onRetrieverError?.(retrieverRun);
  await this._endTrace(retrieverRun);
}

async handleText(text: string, runId: string): Promise<void> {
const run = this.runMap.get(runId);
if (!run || run?.run_type !== "chain") {
Expand Down Expand Up @@ -407,6 +476,12 @@ export abstract class BaseTracer extends BaseCallbackHandler {

onAgentEnd?(run: Run): void | Promise<void>;

/**
 * Optional hook called by `handleRetrieverStart` with the newly created
 * retriever run, after `_startTrace(run)` has been called.
 */
onRetrieverStart?(run: Run): void | Promise<void>;

/**
 * Optional hook called by `handleRetrieverEnd` once the run's end time,
 * outputs and "end" event are set, before `_endTrace(run)` is called.
 */
onRetrieverEnd?(run: Run): void | Promise<void>;

/**
 * Optional hook called by `handleRetrieverError` once the run's end time,
 * error message and "error" event are set, before `_endTrace(run)` is called.
 */
onRetrieverError?(run: Run): void | Promise<void>;

onText?(run: Run): void | Promise<void>;

onLLMNewToken?(run: Run): void | Promise<void>;
Expand Down
12 changes: 12 additions & 0 deletions langchain/src/callbacks/handlers/tracer_langchain.ts
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,18 @@ export class LangChainTracer
await this.client.updateRun(run.id, runUpdate);
}

/** Handles the start of a retriever run by delegating to `_persistRunSingle`. */
async onRetrieverStart(run: Run): Promise<void> {
await this._persistRunSingle(run);
}

/** Handles the end of a retriever run by delegating to `_updateRunSingle`. */
async onRetrieverEnd(run: Run): Promise<void> {
await this._updateRunSingle(run);
}

/** Handles a failed retriever run by delegating to `_updateRunSingle`. */
async onRetrieverError(run: Run): Promise<void> {
await this._updateRunSingle(run);
}

async onLLMStart(run: Run): Promise<void> {
await this._persistRunSingle(run);
}
Expand Down
104 changes: 104 additions & 0 deletions langchain/src/callbacks/manager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import {
} from "./handlers/tracer_langchain.js";
import { consumeCallback } from "./promises.js";
import { Serialized } from "../load/serializable.js";
import { Document } from "../document.js";

type BaseCallbackManagerMethods = {
[K in keyof CallbackHandlerMethods]?: (
Expand Down Expand Up @@ -118,6 +119,69 @@ class BaseRunManager {
}
}

/**
 * Run manager for a single retriever run. Dispatches the end and error
 * events of that run to every registered handler that has not set
 * `ignoreRetriever`, and can create a child `CallbackManager` for nested
 * calls made during the run.
 */
export class CallbackManagerForRetrieverRun
  extends BaseRunManager
  implements BaseCallbackManagerMethods
{
  /**
   * Creates a `CallbackManager` whose parent run id is this run's id and
   * which carries over this run's inheritable handlers, tags and metadata.
   *
   * @param tag Optional extra tag added to the child manager with the
   *   `inherit` flag set to false.
   * @returns The child callback manager.
   */
  getChild(tag?: string): CallbackManager {
    // eslint-disable-next-line @typescript-eslint/no-use-before-define
    const manager = new CallbackManager(this.runId);
    manager.setHandlers(this.inheritableHandlers);
    manager.addTags(this.inheritableTags);
    manager.addMetadata(this.inheritableMetadata);
    if (tag) {
      manager.addTags([tag], false);
    }
    return manager;
  }

  /**
   * Notifies every handler that the retriever run finished.
   * An error thrown by a handler is logged and does not propagate.
   *
   * @param documents The documents produced by the run.
   */
  async handleRetrieverEnd(documents: Document[]): Promise<void> {
    await Promise.all(
      this.handlers.map((handler) =>
        consumeCallback(async () => {
          if (!handler.ignoreRetriever) {
            try {
              await handler.handleRetrieverEnd?.(
                documents,
                this.runId,
                this._parentRunId,
                this.tags
              );
            } catch (err) {
              // Name the failing method and include the caught error, the
              // same way the other retriever dispatchers do.
              console.error(
                `Error in handler ${handler.constructor.name}, handleRetrieverEnd: ${err}`
              );
            }
          }
        }, handler.awaitHandlers)
      )
    );
  }

  /**
   * Notifies every handler that the retriever run failed.
   * An error thrown by a handler is logged and does not propagate.
   *
   * @param err The error raised by the run.
   */
  async handleRetrieverError(err: Error | unknown): Promise<void> {
    await Promise.all(
      this.handlers.map((handler) =>
        consumeCallback(async () => {
          if (!handler.ignoreRetriever) {
            try {
              await handler.handleRetrieverError?.(
                err,
                this.runId,
                this._parentRunId,
                this.tags
              );
            } catch (error) {
              console.error(
                `Error in handler ${handler.constructor.name}, handleRetrieverError: ${error}`
              );
            }
          }
        }, handler.awaitHandlers)
      )
    );
  }
}

export class CallbackManagerForLLMRun
extends BaseRunManager
implements BaseCallbackManagerMethods
Expand Down Expand Up @@ -584,6 +648,46 @@ export class CallbackManager
);
}

/**
 * Notifies every handler that has not set `ignoreRetriever` that a
 * retriever run is starting, then returns a run manager bound to that run.
 * An error thrown by a handler is logged and does not propagate.
 *
 * @param retriever Serialized representation of the retriever.
 * @param query The query string passed to the retriever.
 * @param runId Identifier for the run; a fresh UUID v4 by default.
 * @param _parentRunId Accepted but not read; this manager's own
 *   `_parentRunId` is forwarded to handlers and to the returned run manager.
 * @returns A `CallbackManagerForRetrieverRun` for the started run.
 */
async handleRetrieverStart(
  retriever: Serialized,
  query: string,
  runId: string = uuidv4(),
  _parentRunId: string | undefined = undefined
): Promise<CallbackManagerForRetrieverRun> {
  const pendingNotifications = this.handlers.map((handler) => {
    const notifyHandler = async () => {
      if (handler.ignoreRetriever) {
        return;
      }
      try {
        await handler.handleRetrieverStart?.(
          retriever,
          query,
          runId,
          this._parentRunId,
          this.tags,
          this.metadata
        );
      } catch (err) {
        console.error(
          `Error in handler ${handler.constructor.name}, handleRetrieverStart: ${err}`
        );
      }
    };
    return consumeCallback(notifyHandler, handler.awaitHandlers);
  });
  await Promise.all(pendingNotifications);

  const runManager = new CallbackManagerForRetrieverRun(
    runId,
    this.handlers,
    this.inheritableHandlers,
    this.tags,
    this.inheritableTags,
    this.metadata,
    this.inheritableMetadata,
    this._parentRunId
  );
  return runManager;
}

addHandler(handler: BaseCallbackHandler, inherit = true): void {
this.handlers.push(handler);
if (inherit) {
Expand Down
Loading

0 comments on commit 460e110

Please sign in to comment.