diff --git a/langchain/chains/base.ts b/langchain/chains/base.ts index 69dd5f7f71db..80e1bf0aa159 100644 --- a/langchain/chains/base.ts +++ b/langchain/chains/base.ts @@ -1,4 +1,6 @@ import { LLMChain, StuffDocumentsChain, VectorDBQAChain } from "./index"; +import { BaseMemory } from "../memory"; + // eslint-disable-next-line @typescript-eslint/no-explicit-any export type ChainValues = Record; // eslint-disable-next-line @typescript-eslint/no-explicit-any @@ -10,16 +12,33 @@ export type SerializedBaseChain = ReturnType< InstanceType<(typeof chainClasses)[number]>["serialize"] >; -export abstract class BaseChain { +export interface ChainInputs { + memory?: BaseMemory; +} + +export abstract class BaseChain implements ChainInputs { + memory?: BaseMemory; + abstract _call(values: ChainValues): Promise; abstract _chainType(): string; abstract serialize(): SerializedBaseChain; - call(values: ChainValues): Promise { + async call(values: ChainValues): Promise { + const fullValues = structuredClone(values); + if (!(this.memory == null)) { + const newValues = await this.memory.loadMemoryVariables(values); + for (const [key, value] of Object.entries(newValues)) { + fullValues[key] = value; + } + } // TODO(sean) add callback support - return this._call(values); + const outputValues = this._call(fullValues); + if (!(this.memory == null)) { + this.memory.saveContext(values, outputValues); + } + return outputValues; } apply(inputs: ChainValues[]): ChainValues[] { diff --git a/langchain/memory/base.ts b/langchain/memory/base.ts new file mode 100644 index 000000000000..c1c1d1951ed7 --- /dev/null +++ b/langchain/memory/base.ts @@ -0,0 +1,15 @@ +// eslint-disable-next-line @typescript-eslint/no-explicit-any +export type InputValues = Record; +// eslint-disable-next-line @typescript-eslint/no-explicit-any +export type OutputValues = Record; +// eslint-disable-next-line @typescript-eslint/no-explicit-any +export type MemoryVariables = Record; + +export abstract class BaseMemory { + abstract loadMemoryVariables(values: InputValues): Promise; + + abstract saveContext( + inputValues: InputValues, + OutputValues: Promise + ): Promise; +} diff --git a/langchain/memory/buffer_memory.ts b/langchain/memory/buffer_memory.ts new file mode 100644 index 000000000000..7a64f8e8bb78 --- /dev/null +++ b/langchain/memory/buffer_memory.ts @@ -0,0 +1,50 @@ +import { BaseMemory, InputValues, MemoryVariables, OutputValues } from "./base"; + +export interface BufferMemoryInput { + humanPrefix: string; + aiPrefix: string; + memoryKey: string; +} + +const getInputValue = (inputValues: InputValues) => { + const keys = Object.keys(inputValues); + if (keys.length === 1) { + return inputValues[keys[0]]; + } + throw new Error( + "input values have multiple keys, memory only supported when one key currently" + ); +}; + +export class BufferMemory extends BaseMemory implements BufferMemoryInput { + humanPrefix = "Human"; + + aiPrefix = "AI"; + + memoryKey = "history"; + + buffer = ""; + + constructor(fields?: Partial) { + super(); + this.humanPrefix = fields?.humanPrefix ?? this.humanPrefix; + this.aiPrefix = fields?.aiPrefix ?? this.aiPrefix; + this.memoryKey = fields?.memoryKey ?? this.memoryKey; + } + + async loadMemoryVariables(_values: InputValues): Promise { + const result = { [this.memoryKey]: this.buffer }; + return result; + } + + async saveContext( + inputValues: InputValues, + outputValues: Promise + ): Promise { + const values = await outputValues; + const human = `${this.humanPrefix}: ${getInputValue(inputValues)}`; + const ai = `${this.aiPrefix}: ${getInputValue(values)}`; + const newlines = [human, ai]; + this.buffer += `\n${newlines.join("\n")}`; + } +} diff --git a/langchain/memory/index.ts b/langchain/memory/index.ts new file mode 100644 index 000000000000..c9a4ec3402f8 --- /dev/null +++ b/langchain/memory/index.ts @@ -0,0 +1,2 @@ +export { BufferMemory } from "./buffer_memory"; +export { BaseMemory } from "./base"; diff --git a/langchain/memory/tests/buffer_memory.test.ts b/langchain/memory/tests/buffer_memory.test.ts new file mode 100644 index 000000000000..3cdac331b587 --- /dev/null +++ b/langchain/memory/tests/buffer_memory.test.ts @@ -0,0 +1,17 @@ +import { test, expect } from "@jest/globals"; +import { BufferMemory } from "../buffer_memory"; +import { OutputValues } from "../base"; + +test("Test buffer memory", async () => { + const memory = new BufferMemory(); + const result1 = await memory.loadMemoryVariables({}); + expect(result1).toStrictEqual({ history: "" }); + + const result = new Promise((resolve, _reject) => { + resolve({ bar: "foo" }); + }); + await memory.saveContext({ foo: "bar" }, result); + const expectedString = "\nHuman: bar\nAI: foo"; + const result2 = await memory.loadMemoryVariables({}); + expect(result2).toStrictEqual({ history: expectedString }); +}); diff --git a/langchain/prompt/template.ts b/langchain/prompt/template.ts index 14f25c05b0db..c790b27af843 100644 --- a/langchain/prompt/template.ts +++ b/langchain/prompt/template.ts @@ -53,7 +53,6 @@ export const parseFString = (template: string): ParsedFStringNode[] => { i = next < 0 ? chars.length : next; } } - return nodes; };