forked from langchain-ai/langchainjs
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adding contextual compression with an LLMChainExtractor (langchain-ai…
- Loading branch information
Showing
15 changed files
with
214 additions
and
0 deletions.
There are no files selected for viewing
15 changes: 15 additions & 0 deletions
15
docs/docs/modules/indexes/retrievers/contextual-compression-retriever.mdx
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
# Contextual Compression Retriever | ||
|
||
A Contextual Compression Retriever is designed to improve the answers returned from vector store document similarity searches by better taking into account the context from the query. | ||
|
||
It wraps another retriever, and uses a Document Compressor as an intermediate step after the initial similarity search that removes information irrelevant to the initial query from the retrieved documents. | ||
This reduces the amount of distraction a subsequent chain has to deal with when parsing the retrieved documents and making its final judgements. | ||
|
||
## Usage | ||
|
||
This example shows how to intialize a `ContextualCompressionRetriever` with a vector store and a document compressor: | ||
|
||
import CodeBlock from "@theme/CodeBlock"; | ||
import Example from "@examples/retrievers/contextual_compression.ts"; | ||
|
||
<CodeBlock language="typescript">{Example}</CodeBlock> |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
import * as fs from "fs"; | ||
|
||
import { OpenAI } from "langchain/llms/openai"; | ||
import { RecursiveCharacterTextSplitter } from "langchain/text_splitter"; | ||
import { RetrievalQAChain } from "langchain/chains"; | ||
import { HNSWLib } from "langchain/vectorstores/hnswlib"; | ||
import { OpenAIEmbeddings } from "langchain/embeddings/openai"; | ||
import { ContextualCompressionRetriever } from "langchain/retrievers/contextual_compression"; | ||
import { LLMChainExtractor } from "langchain/retrievers/document_compressors/chain_extract"; | ||
|
||
const model = new OpenAI(); | ||
const baseCompressor = LLMChainExtractor.fromLLM(model); | ||
|
||
const text = fs.readFileSync("state_of_the_union.txt", "utf8"); | ||
|
||
const textSplitter = new RecursiveCharacterTextSplitter({ chunkSize: 1000 }); | ||
const docs = await textSplitter.createDocuments([text]); | ||
|
||
// Create a vector store from the documents. | ||
const vectorStore = await HNSWLib.fromDocuments(docs, new OpenAIEmbeddings()); | ||
|
||
const retriever = new ContextualCompressionRetriever({ | ||
baseCompressor, | ||
baseRetriever: vectorStore.asRetriever(), | ||
}); | ||
|
||
const chain = RetrievalQAChain.fromLLM(model, retriever); | ||
|
||
const res = await chain.call({ | ||
query: "What did the speaker say about Justice Breyer?", | ||
}); | ||
|
||
console.log({ res }); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
90 changes: 90 additions & 0 deletions
90
langchain/src/retrievers/document_compressors/chain_extract.ts
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
import { Document } from "../../document.js"; | ||
import { LLMChain } from "../../chains/llm_chain.js"; | ||
import { PromptTemplate } from "../../prompts/index.js"; | ||
import { BaseLanguageModel } from "../../base_language/index.js"; | ||
import { BaseOutputParser } from "../../schema/output_parser.js"; | ||
import { BaseDocumentCompressor } from "./index.js"; | ||
import { PROMPT_TEMPLATE } from "./chain_extract_prompt.js"; | ||
|
||
function defaultGetInput( | ||
query: string, | ||
doc: Document | ||
): Record<string, unknown> { | ||
return { question: query, context: doc.pageContent }; | ||
} | ||
|
||
class NoOutputParser extends BaseOutputParser<string> { | ||
noOutputStr = "NO_OUTPUT"; | ||
|
||
parse(text: string): Promise<string> { | ||
const cleanedText = text.trim(); | ||
if (cleanedText === this.noOutputStr) { | ||
return Promise.resolve(""); | ||
} | ||
return Promise.resolve(cleanedText); | ||
} | ||
|
||
getFormatInstructions(): string { | ||
throw new Error("Method not implemented."); | ||
} | ||
} | ||
|
||
function getDefaultChainPrompt(): PromptTemplate { | ||
const outputParser = new NoOutputParser(); | ||
const template = PROMPT_TEMPLATE(outputParser.noOutputStr); | ||
return new PromptTemplate({ | ||
template, | ||
inputVariables: ["question", "context"], | ||
outputParser, | ||
}); | ||
} | ||
|
||
export interface LLMChainExtractorArgs { | ||
llmChain: LLMChain; | ||
getInput: (query: string, doc: Document) => Record<string, unknown>; | ||
} | ||
|
||
export class LLMChainExtractor extends BaseDocumentCompressor { | ||
llmChain: LLMChain; | ||
|
||
getInput: (query: string, doc: Document) => Record<string, unknown> = | ||
defaultGetInput; | ||
|
||
constructor({ llmChain, getInput }: LLMChainExtractorArgs) { | ||
super(); | ||
this.llmChain = llmChain; | ||
this.getInput = getInput; | ||
} | ||
|
||
async compressDocuments( | ||
documents: Document[], | ||
query: string | ||
): Promise<Document[]> { | ||
const compressedDocs: Document[] = []; | ||
for (const doc of documents) { | ||
const input = this.getInput(query, doc); | ||
const output = await this.llmChain.predict(input); | ||
if (output.length === 0) { | ||
continue; | ||
} | ||
compressedDocs.push( | ||
new Document({ | ||
pageContent: output, | ||
metadata: doc.metadata, | ||
}) | ||
); | ||
} | ||
return compressedDocs; | ||
} | ||
|
||
static fromLLM( | ||
llm: BaseLanguageModel, | ||
prompt?: PromptTemplate, | ||
getInput?: (query: string, doc: Document) => Record<string, unknown> | ||
): LLMChainExtractor { | ||
const _prompt = prompt || getDefaultChainPrompt(); | ||
const _getInput = getInput || defaultGetInput; | ||
const llmChain = new LLMChain({ llm, prompt: _prompt }); | ||
return new LLMChainExtractor({ llmChain, getInput: _getInput }); | ||
} | ||
} |
12 changes: 12 additions & 0 deletions
12
langchain/src/retrievers/document_compressors/chain_extract_prompt.ts
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
export const PROMPT_TEMPLATE = ( | ||
noOutputStr: string | ||
) => `Given the following question and context, extract any part of the context *AS IS* that is relevant to answer the question. If none of the context is relevant return ${noOutputStr}. | ||
Remember, *DO NOT* edit the extracted parts of the context. | ||
> Question: {question} | ||
> Context: | ||
>>> | ||
{context} | ||
>>> | ||
Extracted relevant parts:`; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
import { test, expect } from "@jest/globals"; | ||
import { OpenAI } from "../../llms/openai.js"; | ||
import { PromptTemplate } from "../../prompts/index.js"; | ||
import { LLMChain } from "../../chains/llm_chain.js"; | ||
import { StuffDocumentsChain } from "../../chains/combine_docs_chain.js"; | ||
import { ConversationalRetrievalQAChain } from "../../chains/conversational_retrieval_chain.js"; | ||
import { HNSWLib } from "../../vectorstores/hnswlib.js"; | ||
import { OpenAIEmbeddings } from "../../embeddings/openai.js"; | ||
import { ContextualCompressionRetriever } from "../contextual_compression.js"; | ||
import { LLMChainExtractor } from "../document_compressors/chain_extract.js"; | ||
|
||
test("Test LLMChainExtractor", async () => { | ||
const model = new OpenAI({ modelName: "text-ada-001" }); | ||
const prompt = PromptTemplate.fromTemplate( | ||
"Print {question}, and ignore {chat_history}" | ||
); | ||
const baseCompressor = LLMChainExtractor.fromLLM(model); | ||
expect(baseCompressor).toBeDefined(); | ||
|
||
const retriever = new ContextualCompressionRetriever({ | ||
baseCompressor, | ||
baseRetriever: await HNSWLib.fromTexts( | ||
["Hello world", "Bye bye", "hello nice world", "bye", "hi"], | ||
[{ id: 2 }, { id: 1 }, { id: 3 }, { id: 4 }, { id: 5 }], | ||
new OpenAIEmbeddings() | ||
).then((vectorStore) => vectorStore.asRetriever()), | ||
}); | ||
|
||
const llmChain = new LLMChain({ prompt, llm: model }); | ||
const combineDocsChain = new StuffDocumentsChain({ | ||
llmChain, | ||
documentVariableName: "foo", | ||
}); | ||
const chain = new ConversationalRetrievalQAChain({ | ||
retriever, | ||
combineDocumentsChain: combineDocsChain, | ||
questionGeneratorChain: llmChain, | ||
}); | ||
const res = await chain.call({ question: "foo", chat_history: "bar" }); | ||
|
||
expect(res.text.length).toBeGreaterThan(0); | ||
|
||
console.log({ res }); | ||
}); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters