forked from langchain-ai/langchainjs
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implementation of a VectorStoreMemory (langchain-ai#452)
* Implementation of a VectorStoreMemory * Revise implementation to match py, add test * Add export * Add example * Add docs
- Loading branch information
Showing
6 changed files
with
192 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
--- | ||
hide_table_of_contents: true | ||
--- | ||
|
||
import CodeBlock from "@theme/CodeBlock"; | ||
import Example from "@examples/memory/vector_store.ts"; | ||
|
||
# VectorStore-backed Memory | ||
|
||
`VectorStoreRetrieverMemory` stores memories in a VectorDB and queries the top-K most "salient" docs every time it is called. | ||
|
||
This differs from most of the other Memory classes in that it doesn't explicitly track the order of interactions. | ||
|
||
In this case, the "docs" are previous conversation snippets. This can be useful to refer to relevant pieces of information that the AI was told earlier in the conversation. | ||
|
||
<CodeBlock language="typescript">{Example}</CodeBlock> |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
import { OpenAI } from "langchain/llms/openai"; | ||
import { VectorStoreRetrieverMemory } from "langchain/memory"; | ||
import { LLMChain } from "langchain/chains"; | ||
import { PromptTemplate } from "langchain/prompts"; | ||
import { MemoryVectorStore } from "langchain/vectorstores/memory"; | ||
import { OpenAIEmbeddings } from "langchain/embeddings/openai"; | ||
|
||
const vectorStore = new MemoryVectorStore(new OpenAIEmbeddings()); | ||
const memory = new VectorStoreRetrieverMemory({ | ||
// 1 is how many documents to return, you might want to return more, eg. 4 | ||
vectorStoreRetriever: vectorStore.asRetriever(1), | ||
memoryKey: "history", | ||
}); | ||
|
||
// First let's save some information to memory, as it would happen when | ||
// used inside a chain. | ||
await memory.saveContext( | ||
{ input: "My favorite food is pizza" }, | ||
{ output: "thats good to know" } | ||
); | ||
await memory.saveContext( | ||
{ input: "My favorite sport is soccer" }, | ||
{ output: "..." } | ||
); | ||
await memory.saveContext({ input: "I don't the Celtics" }, { output: "ok" }); | ||
|
||
// Now let's use the memory to retrieve the information we saved. | ||
console.log( | ||
await memory.loadMemoryVariables({ prompt: "what sport should i watch?" }) | ||
); | ||
/* | ||
{ history: 'input: My favorite sport is soccer\noutput: ...' } | ||
*/ | ||
|
||
// Now let's use it in a chain. | ||
const model = new OpenAI({ temperature: 0.9 }); | ||
const prompt = | ||
PromptTemplate.fromTemplate(`The following is a friendly conversation between a human and an AI. The AI is talkative and provides lots of specific details from its context. If the AI does not know the answer to a question, it truthfully says it does not know. | ||
Relevant pieces of previous conversation: | ||
{history} | ||
(You do not need to use these pieces of information if not relevant) | ||
Current conversation: | ||
Human: {input} | ||
AI:`); | ||
const chain = new LLMChain({ llm: model, prompt, memory }); | ||
|
||
const res1 = await chain.call({ input: "Hi, my name is Perry, what's up?" }); | ||
console.log({ res1 }); | ||
/* | ||
{ | ||
res1: { | ||
text: " Hi Perry, I'm doing great! I'm currently exploring different topics related to artificial intelligence like natural language processing and machine learning. What about you? What have you been up to lately?" | ||
} | ||
} | ||
*/ | ||
|
||
const res2 = await chain.call({ input: "what's my favorite sport?" }); | ||
console.log({ res2 }); | ||
/* | ||
{ res2: { text: ' You said your favorite sport is soccer.' } } | ||
*/ | ||
|
||
const res3 = await chain.call({ input: "what's my name?" }); | ||
console.log({ res3 }); | ||
/* | ||
{ res3: { text: ' Your name is Perry.' } } | ||
*/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
35 changes: 35 additions & 0 deletions
35
langchain/src/memory/tests/vector_store_memory.int.test.ts
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
import { test, expect } from "@jest/globals"; | ||
import "@tensorflow/tfjs-backend-cpu"; | ||
import { VectorStoreRetrieverMemory } from "../vector_store.js"; | ||
import { MemoryVectorStore } from "../../vectorstores/memory.js"; | ||
import { TensorFlowEmbeddings } from "../../embeddings/tensorflow.js"; | ||
import { Document } from "../../document.js"; | ||
|
||
test("Test vector store memory", async () => { | ||
const vectorStore = new MemoryVectorStore(new TensorFlowEmbeddings()); | ||
const memory = new VectorStoreRetrieverMemory({ | ||
vectorStoreRetriever: vectorStore.asRetriever(), | ||
}); | ||
const result1 = await memory.loadMemoryVariables({ input: "foo" }); | ||
expect(result1).toStrictEqual({ memory: "" }); | ||
|
||
await memory.saveContext({ foo: "bar" }, { bar: "foo" }); | ||
const expectedString = "foo: bar\nbar: foo"; | ||
const result2 = await memory.loadMemoryVariables({ input: "foo" }); | ||
expect(result2).toStrictEqual({ memory: expectedString }); | ||
}); | ||
|
||
test("Test vector store memory return docs", async () => { | ||
const vectorStore = new MemoryVectorStore(new TensorFlowEmbeddings()); | ||
const memory = new VectorStoreRetrieverMemory({ | ||
vectorStoreRetriever: vectorStore.asRetriever(), | ||
returnDocs: true, | ||
}); | ||
const result1 = await memory.loadMemoryVariables({ input: "foo" }); | ||
expect(result1).toStrictEqual({ memory: [] }); | ||
|
||
await memory.saveContext({ foo: "bar" }, { bar: "foo" }); | ||
const expectedResult = [new Document({ pageContent: "foo: bar\nbar: foo" })]; | ||
const result2 = await memory.loadMemoryVariables({ input: "foo" }); | ||
expect(result2).toStrictEqual({ memory: expectedResult }); | ||
}); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
import { Document } from "../document.js"; | ||
import { VectorStoreRetriever } from "../vectorstores/base.js"; | ||
import { | ||
BaseMemory, | ||
getInputValue, | ||
InputValues, | ||
MemoryVariables, | ||
OutputValues, | ||
} from "./base.js"; | ||
|
||
export interface VectorStoreRetrieverMemoryParams { | ||
vectorStoreRetriever: VectorStoreRetriever; | ||
inputKey?: string; | ||
outputKey?: string; | ||
memoryKey?: string; | ||
returnDocs?: boolean; | ||
} | ||
|
||
export class VectorStoreRetrieverMemory | ||
extends BaseMemory | ||
implements VectorStoreRetrieverMemoryParams | ||
{ | ||
vectorStoreRetriever: VectorStoreRetriever; | ||
|
||
inputKey?: string; | ||
|
||
memoryKey: string; | ||
|
||
returnDocs: boolean; | ||
|
||
constructor(fields: VectorStoreRetrieverMemoryParams) { | ||
super(); | ||
this.vectorStoreRetriever = fields.vectorStoreRetriever; | ||
this.inputKey = fields.inputKey; | ||
this.memoryKey = fields.memoryKey ?? "memory"; | ||
this.returnDocs = fields.returnDocs ?? false; | ||
} | ||
|
||
get memoryKeys(): string[] { | ||
return [this.memoryKey]; | ||
} | ||
|
||
async loadMemoryVariables(values: InputValues): Promise<MemoryVariables> { | ||
const query = getInputValue(values, this.inputKey); | ||
const results = await this.vectorStoreRetriever.getRelevantDocuments(query); | ||
return { | ||
[this.memoryKey]: this.returnDocs | ||
? results | ||
: results.map((r) => r.pageContent).join("\n"), | ||
}; | ||
} | ||
|
||
async saveContext( | ||
inputValues: InputValues, | ||
outputValues: OutputValues | ||
): Promise<void> { | ||
const text = Object.entries(inputValues) | ||
.filter(([k]) => k !== this.memoryKey) | ||
.concat(Object.entries(outputValues)) | ||
.map(([k, v]) => `${k}: ${v}`) | ||
.join("\n"); | ||
await this.vectorStoreRetriever.addDocuments([ | ||
new Document({ pageContent: text }), | ||
]); | ||
} | ||
} |