Skip to content

Commit

Permalink
Adding contextual compression with an LLMChainExtractor (langchain-ai…
Browse files Browse the repository at this point in the history
  • Loading branch information
joseanu authored May 4, 2023
1 parent 0de32d2 commit 220c47e
Show file tree
Hide file tree
Showing 15 changed files with 214 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Contextual Compression Retriever

A Contextual Compression Retriever is designed to improve the answers returned from vector store document similarity searches by better taking into account the context from the query.

It wraps another retriever, and uses a Document Compressor as an intermediate step after the initial similarity search that removes information irrelevant to the initial query from the retrieved documents.
This reduces the amount of distraction a subsequent chain has to deal with when parsing the retrieved documents and making its final judgements.

## Usage

This example shows how to initialize a `ContextualCompressionRetriever` with a vector store and a document compressor:

import CodeBlock from "@theme/CodeBlock";
import Example from "@examples/retrievers/contextual_compression.ts";

<CodeBlock language="typescript">{Example}</CodeBlock>
33 changes: 33 additions & 0 deletions examples/src/retrievers/contextual_compression.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import * as fs from "fs";

import { OpenAI } from "langchain/llms/openai";
import { RecursiveCharacterTextSplitter } from "langchain/text_splitter";
import { RetrievalQAChain } from "langchain/chains";
import { HNSWLib } from "langchain/vectorstores/hnswlib";
import { OpenAIEmbeddings } from "langchain/embeddings/openai";
import { ContextualCompressionRetriever } from "langchain/retrievers/contextual_compression";
import { LLMChainExtractor } from "langchain/retrievers/document_compressors/chain_extract";

// Demonstrates contextual compression: documents returned by a vector
// store similarity search are post-processed by an LLM so that only the
// parts relevant to the query reach the downstream QA chain.
const model = new OpenAI();

// The compressor extracts query-relevant passages from each retrieved document.
const baseCompressor = LLMChainExtractor.fromLLM(model);

// Load the source text and split it into ~1000-character chunks.
const text = fs.readFileSync("state_of_the_union.txt", "utf8");
const splitter = new RecursiveCharacterTextSplitter({ chunkSize: 1000 });
const documents = await splitter.createDocuments([text]);

// Embed the chunks and index them in an in-memory HNSW vector store.
const vectorStore = await HNSWLib.fromDocuments(
  documents,
  new OpenAIEmbeddings()
);

// Wrap the store's retriever so results are compressed before use.
const retriever = new ContextualCompressionRetriever({
  baseCompressor,
  baseRetriever: vectorStore.asRetriever(),
});

const chain = RetrievalQAChain.fromLLM(model, retriever);

const res = await chain.call({
  query: "What did the speaker say about Justice Breyer?",
});

console.log({ res });
3 changes: 3 additions & 0 deletions langchain/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,9 @@ retrievers/contextual_compression.d.ts
retrievers/document_compressors.cjs
retrievers/document_compressors.js
retrievers/document_compressors.d.ts
retrievers/document_compressors/chain_extract.cjs
retrievers/document_compressors/chain_extract.js
retrievers/document_compressors/chain_extract.d.ts
retrievers/hyde.cjs
retrievers/hyde.js
retrievers/hyde.d.ts
Expand Down
8 changes: 8 additions & 0 deletions langchain/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,9 @@
"retrievers/document_compressors.cjs",
"retrievers/document_compressors.js",
"retrievers/document_compressors.d.ts",
"retrievers/document_compressors/chain_extract.cjs",
"retrievers/document_compressors/chain_extract.js",
"retrievers/document_compressors/chain_extract.d.ts",
"retrievers/hyde.cjs",
"retrievers/hyde.js",
"retrievers/hyde.d.ts",
Expand Down Expand Up @@ -923,6 +926,11 @@
"import": "./retrievers/document_compressors.js",
"require": "./retrievers/document_compressors.cjs"
},
"./retrievers/document_compressors/chain_extract": {
"types": "./retrievers/document_compressors/chain_extract.d.ts",
"import": "./retrievers/document_compressors/chain_extract.js",
"require": "./retrievers/document_compressors/chain_extract.cjs"
},
"./retrievers/hyde": {
"types": "./retrievers/hyde.d.ts",
"import": "./retrievers/hyde.js",
Expand Down
2 changes: 2 additions & 0 deletions langchain/scripts/create-entrypoints.js
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,8 @@ const entrypoints = {
"retrievers/databerry": "retrievers/databerry",
"retrievers/contextual_compression": "retrievers/contextual_compression",
"retrievers/document_compressors": "retrievers/document_compressors/index",
"retrievers/document_compressors/chain_extract":
"retrievers/document_compressors/chain_extract",
"retrievers/hyde": "retrievers/hyde",
// cache
cache: "cache/index",
Expand Down
90 changes: 90 additions & 0 deletions langchain/src/retrievers/document_compressors/chain_extract.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import { Document } from "../../document.js";
import { LLMChain } from "../../chains/llm_chain.js";
import { PromptTemplate } from "../../prompts/index.js";
import { BaseLanguageModel } from "../../base_language/index.js";
import { BaseOutputParser } from "../../schema/output_parser.js";
import { BaseDocumentCompressor } from "./index.js";
import { PROMPT_TEMPLATE } from "./chain_extract_prompt.js";

/**
 * Default mapping from a (query, document) pair to the input variables
 * the extraction prompt expects: `question` and `context`.
 */
function defaultGetInput(
  query: string,
  doc: Document
): Record<string, unknown> {
  const question = query;
  const context = doc.pageContent;
  return { question, context };
}

/**
 * Output parser that collapses the model's "nothing relevant" sentinel
 * reply to an empty string; any other completion is returned trimmed.
 */
class NoOutputParser extends BaseOutputParser<string> {
  // Sentinel the prompt instructs the model to emit when no part of the
  // context is relevant to the question.
  noOutputStr = "NO_OUTPUT";

  async parse(text: string): Promise<string> {
    const trimmed = text.trim();
    return trimmed === this.noOutputStr ? "" : trimmed;
  }

  // Formatting instructions are baked into the prompt template instead,
  // so this parser never supplies its own.
  getFormatInstructions(): string {
    throw new Error("Method not implemented.");
  }
}

/**
 * Builds the default extraction prompt, wired to a NoOutputParser so a
 * bare "NO_OUTPUT" completion is collapsed to the empty string.
 */
function getDefaultChainPrompt(): PromptTemplate {
  const outputParser = new NoOutputParser();
  return new PromptTemplate({
    template: PROMPT_TEMPLATE(outputParser.noOutputStr),
    inputVariables: ["question", "context"],
    outputParser,
  });
}

export interface LLMChainExtractorArgs {
  /** Chain used to extract the query-relevant parts of a document. */
  llmChain: LLMChain;
  /**
   * Maps a (query, document) pair to the chain's input variables.
   * Defaults to `defaultGetInput` ({ question, context }).
   */
  getInput?: (query: string, doc: Document) => Record<string, unknown>;
}

/**
 * Document compressor that runs each retrieved document through an LLM
 * chain and keeps only the extracted, query-relevant content. Documents
 * for which the chain yields an empty extraction are dropped entirely;
 * surviving documents retain their original metadata.
 */
export class LLMChainExtractor extends BaseDocumentCompressor {
  llmChain: LLMChain;

  getInput: (query: string, doc: Document) => Record<string, unknown>;

  constructor({ llmChain, getInput }: LLMChainExtractorArgs) {
    super();
    this.llmChain = llmChain;
    // Fall back to the default input mapping when none is supplied.
    // Previously the field initializer default was dead code: the args
    // required getInput and the constructor always overwrote the field.
    this.getInput = getInput ?? defaultGetInput;
  }

  /**
   * Compress a list of documents against a query.
   *
   * @param documents Documents returned by the base retriever.
   * @param query The user query to extract relevant content for.
   * @returns Documents whose pageContent is the chain's extraction,
   *   in the same order as the input, with empty extractions removed.
   */
  async compressDocuments(
    documents: Document[],
    query: string
  ): Promise<Document[]> {
    // Run the extraction chain over all documents in parallel; the
    // calls are independent, so sequential awaits were needlessly slow.
    const outputs = await Promise.all(
      documents.map((doc) => this.llmChain.predict(this.getInput(query, doc)))
    );
    const compressedDocs: Document[] = [];
    outputs.forEach((output, i) => {
      if (output.length === 0) {
        return; // Nothing relevant was extracted; drop the document.
      }
      compressedDocs.push(
        new Document({
          pageContent: output,
          metadata: documents[i].metadata,
        })
      );
    });
    return compressedDocs;
  }

  /**
   * Convenience factory: build an extractor from a bare LLM, optionally
   * overriding the prompt and the input-mapping function.
   */
  static fromLLM(
    llm: BaseLanguageModel,
    prompt?: PromptTemplate,
    getInput?: (query: string, doc: Document) => Record<string, unknown>
  ): LLMChainExtractor {
    const _prompt = prompt || getDefaultChainPrompt();
    const _getInput = getInput || defaultGetInput;
    const llmChain = new LLMChain({ llm, prompt: _prompt });
    return new LLMChainExtractor({ llmChain, getInput: _getInput });
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
/**
 * Builds the extraction prompt used by `LLMChainExtractor`'s default chain.
 *
 * @param noOutputStr Sentinel string the model is told to return verbatim
 *   when nothing in the context is relevant (matched by NoOutputParser,
 *   which maps it to an empty string so the document is dropped).
 * @returns A prompt template with `{question}` and `{context}` variables.
 */
export const PROMPT_TEMPLATE = (
  noOutputStr: string
) => `Given the following question and context, extract any part of the context *AS IS* that is relevant to answer the question. If none of the context is relevant return ${noOutputStr}.
Remember, *DO NOT* edit the extracted parts of the context.
> Question: {question}
> Context:
>>>
{context}
>>>
Extracted relevant parts:`;
44 changes: 44 additions & 0 deletions langchain/src/retrievers/tests/chain_extract.int.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import { test, expect } from "@jest/globals";
import { OpenAI } from "../../llms/openai.js";
import { PromptTemplate } from "../../prompts/index.js";
import { LLMChain } from "../../chains/llm_chain.js";
import { StuffDocumentsChain } from "../../chains/combine_docs_chain.js";
import { ConversationalRetrievalQAChain } from "../../chains/conversational_retrieval_chain.js";
import { HNSWLib } from "../../vectorstores/hnswlib.js";
import { OpenAIEmbeddings } from "../../embeddings/openai.js";
import { ContextualCompressionRetriever } from "../contextual_compression.js";
import { LLMChainExtractor } from "../document_compressors/chain_extract.js";

test("Test LLMChainExtractor", async () => {
  const model = new OpenAI({ modelName: "text-ada-001" });
  const prompt = PromptTemplate.fromTemplate(
    "Print {question}, and ignore {chat_history}"
  );

  // Compressor built with the default extraction prompt.
  const baseCompressor = LLMChainExtractor.fromLLM(model);
  expect(baseCompressor).toBeDefined();

  // Index a handful of short texts, then wrap the store's retriever
  // with contextual compression.
  const vectorStore = await HNSWLib.fromTexts(
    ["Hello world", "Bye bye", "hello nice world", "bye", "hi"],
    [{ id: 2 }, { id: 1 }, { id: 3 }, { id: 4 }, { id: 5 }],
    new OpenAIEmbeddings()
  );
  const retriever = new ContextualCompressionRetriever({
    baseCompressor,
    baseRetriever: vectorStore.asRetriever(),
  });

  // Minimal conversational QA chain driven by the compressed retriever.
  const llmChain = new LLMChain({ prompt, llm: model });
  const combineDocsChain = new StuffDocumentsChain({
    llmChain,
    documentVariableName: "foo",
  });
  const chain = new ConversationalRetrievalQAChain({
    retriever,
    combineDocumentsChain: combineDocsChain,
    questionGeneratorChain: llmChain,
  });

  const res = await chain.call({ question: "foo", chat_history: "bar" });

  expect(res.text.length).toBeGreaterThan(0);

  console.log({ res });
});
1 change: 1 addition & 0 deletions langchain/tsconfig.json
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@
"src/retrievers/databerry.ts",
"src/retrievers/contextual_compression.ts",
"src/retrievers/document_compressors/index.ts",
"src/retrievers/document_compressors/chain_extract.ts",
"src/retrievers/hyde.ts",
"src/cache/index.ts",
"src/cache/redis.ts",
Expand Down
1 change: 1 addition & 0 deletions test-exports-cf/src/entrypoints.js
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ export * from "langchain/retrievers/remote";
export * from "langchain/retrievers/databerry";
export * from "langchain/retrievers/contextual_compression";
export * from "langchain/retrievers/document_compressors";
export * from "langchain/retrievers/document_compressors/chain_extract";
export * from "langchain/retrievers/hyde";
export * from "langchain/cache";
export * from "langchain/stores/file/in_memory";
Expand Down
1 change: 1 addition & 0 deletions test-exports-cjs/src/entrypoints.js
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ const retrievers_remote = require("langchain/retrievers/remote");
const retrievers_databerry = require("langchain/retrievers/databerry");
const retrievers_contextual_compression = require("langchain/retrievers/contextual_compression");
const retrievers_document_compressors = require("langchain/retrievers/document_compressors");
const retrievers_document_compressors_chain_extract = require("langchain/retrievers/document_compressors/chain_extract");
const retrievers_hyde = require("langchain/retrievers/hyde");
const cache = require("langchain/cache");
const stores_file_in_memory = require("langchain/stores/file/in_memory");
Expand Down
1 change: 1 addition & 0 deletions test-exports-cra/src/entrypoints.js
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ export * from "langchain/retrievers/remote";
export * from "langchain/retrievers/databerry";
export * from "langchain/retrievers/contextual_compression";
export * from "langchain/retrievers/document_compressors";
export * from "langchain/retrievers/document_compressors/chain_extract";
export * from "langchain/retrievers/hyde";
export * from "langchain/cache";
export * from "langchain/stores/file/in_memory";
Expand Down
1 change: 1 addition & 0 deletions test-exports-esm/src/entrypoints.js
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import * as retrievers_remote from "langchain/retrievers/remote";
import * as retrievers_databerry from "langchain/retrievers/databerry";
import * as retrievers_contextual_compression from "langchain/retrievers/contextual_compression";
import * as retrievers_document_compressors from "langchain/retrievers/document_compressors";
import * as retrievers_document_compressors_chain_extract from "langchain/retrievers/document_compressors/chain_extract";
import * as retrievers_hyde from "langchain/retrievers/hyde";
import * as cache from "langchain/cache";
import * as stores_file_in_memory from "langchain/stores/file/in_memory";
Expand Down
1 change: 1 addition & 0 deletions test-exports-vercel/src/entrypoints.js
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ export * from "langchain/retrievers/remote";
export * from "langchain/retrievers/databerry";
export * from "langchain/retrievers/contextual_compression";
export * from "langchain/retrievers/document_compressors";
export * from "langchain/retrievers/document_compressors/chain_extract";
export * from "langchain/retrievers/hyde";
export * from "langchain/cache";
export * from "langchain/stores/file/in_memory";
Expand Down
1 change: 1 addition & 0 deletions test-exports-vite/src/entrypoints.js
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ export * from "langchain/retrievers/remote";
export * from "langchain/retrievers/databerry";
export * from "langchain/retrievers/contextual_compression";
export * from "langchain/retrievers/document_compressors";
export * from "langchain/retrievers/document_compressors/chain_extract";
export * from "langchain/retrievers/hyde";
export * from "langchain/cache";
export * from "langchain/stores/file/in_memory";
Expand Down

0 comments on commit 220c47e

Please sign in to comment.