Skip to content

Commit

Permalink
add memory class (langchain-ai#14)
Browse files Browse the repository at this point in the history
* vector db qa chain

* cr

* add memory class

* cr

* conversation chain (langchain-ai#15)
  • Loading branch information
hwchase17 authored Feb 16, 2023
1 parent 57448c6 commit 33df92c
Show file tree
Hide file tree
Showing 6 changed files with 106 additions and 4 deletions.
25 changes: 22 additions & 3 deletions langchain/chains/base.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import { LLMChain, StuffDocumentsChain, VectorDBQAChain } from "./index";
import { BaseMemory } from "../memory";

// eslint-disable-next-line @typescript-eslint/no-explicit-any
export type ChainValues = Record<string, any>;
// eslint-disable-next-line @typescript-eslint/no-explicit-any
Expand All @@ -10,16 +12,33 @@ export type SerializedBaseChain = ReturnType<
InstanceType<(typeof chainClasses)[number]>["serialize"]
>;

export abstract class BaseChain {
export interface ChainInputs {
memory?: BaseMemory;
}

export abstract class BaseChain implements ChainInputs {
memory?: BaseMemory;

abstract _call(values: ChainValues): Promise<ChainValues>;

abstract _chainType(): string;

abstract serialize(): SerializedBaseChain;

call(values: ChainValues): Promise<ChainValues> {
async call(values: ChainValues): Promise<ChainValues> {
const fullValues = structuredClone(values);
if (!(this.memory == null)) {
const newValues = await this.memory.loadMemoryVariables(values);
for (const [key, value] of Object.entries(newValues)) {
fullValues[key] = value;
}
}
// TODO(sean) add callback support
return this._call(values);
const outputValues = this._call(fullValues);
if (!(this.memory == null)) {
this.memory.saveContext(values, outputValues);
}
return outputValues;
}

apply(inputs: ChainValues[]): ChainValues[] {
Expand Down
15 changes: 15 additions & 0 deletions langchain/memory/base.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// eslint-disable-next-line @typescript-eslint/no-explicit-any
export type InputValues = Record<string, any>;
// eslint-disable-next-line @typescript-eslint/no-explicit-any
export type OutputValues = Record<string, any>;
// eslint-disable-next-line @typescript-eslint/no-explicit-any
export type MemoryVariables = Record<string, any>;

export abstract class BaseMemory {
abstract loadMemoryVariables(values: InputValues): Promise<MemoryVariables>;

abstract saveContext(
inputValues: InputValues,
OutputValues: Promise<OutputValues>
): Promise<void>;
}
50 changes: 50 additions & 0 deletions langchain/memory/buffer_memory.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import { BaseMemory, InputValues, MemoryVariables, OutputValues } from "./base";

export interface BufferMemoryInput {
humanPrefix: string;
aiPrefix: string;
memoryKey: string;
}

const getInputValue = (inputValues: InputValues) => {
const keys = Object.keys(inputValues);
if (keys.length === 1) {
return inputValues[keys[0]];
}
throw new Error(
"input values have multiple keys, memory only supported when one key currently"
);
};

export class BufferMemory extends BaseMemory implements BufferMemoryInput {
humanPrefix = "Human";

aiPrefix = "AI";

memoryKey = "history";

buffer = "";

constructor(fields?: Partial<BufferMemoryInput>) {
super();
this.humanPrefix = fields?.humanPrefix ?? this.humanPrefix;
this.aiPrefix = fields?.aiPrefix ?? this.aiPrefix;
this.memoryKey = fields?.memoryKey ?? this.memoryKey;
}

async loadMemoryVariables(_values: InputValues): Promise<MemoryVariables> {
const result = { [this.memoryKey]: this.buffer };
return result;
}

async saveContext(
inputValues: InputValues,
outputValues: Promise<OutputValues>
): Promise<void> {
const values = await outputValues;
const human = `${this.humanPrefix}: ${getInputValue(inputValues)}`;
const ai = `${this.aiPrefix}: ${getInputValue(values)}`;
const newlines = [human, ai];
this.buffer += `\n${newlines.join("\n")}`;
}
}
2 changes: 2 additions & 0 deletions langchain/memory/index.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
export { BufferMemory } from "./buffer_memory";
export { BaseMemory } from "./base";
17 changes: 17 additions & 0 deletions langchain/memory/tests/buffer_memory.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import { test, expect } from "@jest/globals";
import { BufferMemory } from "../buffer_memory";
import { OutputValues } from "../base";

test("Test buffer memory", async () => {
const memory = new BufferMemory();
const result1 = await memory.loadMemoryVariables({});
expect(result1).toStrictEqual({ history: "" });

const result = new Promise<OutputValues>((resolve, _reject) => {
resolve({ bar: "foo" });
});
await memory.saveContext({ foo: "bar" }, result);
const expectedString = "\nHuman: bar\nAI: foo";
const result2 = await memory.loadMemoryVariables({});
expect(result2).toStrictEqual({ history: expectedString });
});
1 change: 0 additions & 1 deletion langchain/prompt/template.ts
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@ export const parseFString = (template: string): ParsedFStringNode[] => {
i = next < 0 ? chars.length : next;
}
}

return nodes;
};

Expand Down

0 comments on commit 33df92c

Please sign in to comment.