langchain[patch]: Add possibility to rerank retrieved docs in ParentDocumentRetriever and MultiQueryRetriever (langchain-ai#4738)

* feat: add Document Compressor to chain to allow rerank

* feat: add example

* fix: typo

* feat: rerank child documents instead of parent ones

* feat: improve example

* feat: add compressor to Multi Query Retriever

* feat: remove example

* feat: remove example

* fix: missing docs

* feat: remove default threshold score value to better suit new Cohere models

* feat: make filtering optional and configurable

* docs: add examples

* fix: type checking so the build passes

* Fix lint

---------

Co-authored-by: jacoblee93 <[email protected]>
karol-f and jacoblee93 authored Apr 22, 2024
1 parent c35fd25 commit 6407078
Showing 4 changed files with 157 additions and 10 deletions.
@@ -6,6 +6,7 @@ import CodeBlock from "@theme/CodeBlock";
import Example from "@examples/retrievers/parent_document_retriever.ts";
import ExampleWithScoreThreshold from "@examples/retrievers/parent_document_retriever_score_threshold.ts";
import ExampleWithChunkHeader from "@examples/retrievers/parent_document_retriever_chunk_header.ts";
import ExampleWithRerank from "@examples/retrievers/parent_document_retriever_rerank.ts";

# Parent Document Retriever

@@ -50,3 +51,12 @@ Tagging each document with metadata is a solution if you know what to filter against
This is particularly important if you have several fine-grained child chunks that need to be correctly retrieved from the vector store.

<CodeBlock language="typescript">{ExampleWithChunkHeader}</CodeBlock>

## With Reranking

When many documents retrieved from the vector store are passed to the LLM, the final answer can draw on information from
irrelevant chunks, making it less precise or even incorrect. Passing many irrelevant documents also makes the call more expensive.
So there are two reasons to rerank: precision and cost.

<CodeBlock language="typescript">{ExampleWithRerank}</CodeBlock>
93 changes: 93 additions & 0 deletions examples/src/retrievers/parent_document_retriever_rerank.ts
@@ -0,0 +1,93 @@
import { OpenAIEmbeddings } from "@langchain/openai";
import { CohereRerank } from "@langchain/cohere";
import { HNSWLib } from "@langchain/community/vectorstores/hnswlib";
import { InMemoryStore } from "langchain/storage/in_memory";
import {
  ParentDocumentRetriever,
  type SubDocs,
} from "langchain/retrievers/parent_document";
import { RecursiveCharacterTextSplitter } from "langchain/text_splitter";

// Initialize the Cohere reranker. Remember to add COHERE_API_KEY to your .env
const reranker = new CohereRerank({
  topN: 50,
  model: "rerank-multilingual-v2.0",
});
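// Note: topN matches childK on the retriever below, so the reranker scores
// every retrieved child chunk and the score-threshold filter (rather than
// the reranker itself) decides which chunks are dropped.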

// Returns a filtering function that drops documents whose reranker
// relevanceScore is below the given threshold. Documents without a score
// are kept (they default to 1).
export function documentCompressorFiltering({
  relevanceScore,
}: { relevanceScore?: number } = {}) {
  return (docs: SubDocs) => {
    let outputDocs = docs;

    if (relevanceScore) {
      const docsRelevanceScoreValues = docs.map(
        (doc) => doc?.metadata?.relevanceScore
      );
      outputDocs = docs.filter(
        (_doc, index) =>
          (docsRelevanceScoreValues?.[index] || 1) >= relevanceScore
      );
    }

    return outputDocs;
  };
}

const splitter = new RecursiveCharacterTextSplitter({
  chunkSize: 500,
  chunkOverlap: 0,
});

const jimDocs = await splitter.createDocuments([`Jim favorite color is blue.`]);

const pamDocs = await splitter.createDocuments([`Pam favorite color is red.`]);

const vectorstore = await HNSWLib.fromDocuments([], new OpenAIEmbeddings());
const docstore = new InMemoryStore();

const retriever = new ParentDocumentRetriever({
  vectorstore,
  docstore,
  // Very small chunks for demo purposes.
  // Use a bigger chunk size for serious use-cases.
  childSplitter: new RecursiveCharacterTextSplitter({
    chunkSize: 10,
    chunkOverlap: 0,
  }),
  childK: 50,
  parentK: 5,
  // Pass the reranker as the document compressor...
  documentCompressor: reranker,
  // ...and drop child chunks it scores below 0.3.
  documentCompressorFilteringFn: documentCompressorFiltering({
    relevanceScore: 0.3,
  }),
});

const docs = jimDocs.concat(pamDocs);
await retriever.addDocuments(docs);

// This searches the vector store for child chunks, reranks them, filters out
// chunks below the minimum relevance score, and returns the corresponding
// parent documents, ready to pass to an LLM.
const retrievedDocs = await retriever.getRelevantDocuments(
  "What is Pam's favorite color?"
);

// Pam's favorite color is returned first!
console.log(JSON.stringify(retrievedDocs, null, 2));
/*
  [
    {
      "pageContent": "Pam favorite color is red.",
      "metadata": {
        "relevanceScore": 0.9,
        "loc": {
          "lines": {
            "from": 1,
            "to": 1
          }
        }
      }
    }
  ]
*/
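
The score-threshold helper above is only one possibility: documentCompressorFilteringFn accepts any function from reranked child chunks to child chunks. Below is a minimal sketch of an alternative filter that simply keeps the three highest-scoring chunks; the cutoff of 3 and the reuse of reranker, vectorstore, and docstore from the example above are illustrative assumptions, not part of the commit.

// Hypothetical alternative filter: keep only the three top-ranked child chunks.
// The reranker returns chunks ordered by relevance score, so slicing keeps the best ones.
const keepTopThree = (docs: SubDocs): SubDocs => docs.slice(0, 3);

const topNRetriever = new ParentDocumentRetriever({
  vectorstore,
  docstore,
  childSplitter: new RecursiveCharacterTextSplitter({
    chunkSize: 10,
    chunkOverlap: 0,
  }),
  childK: 50,
  parentK: 5,
  documentCompressor: reranker,
  documentCompressorFilteringFn: keepTopThree,
});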
42 changes: 34 additions & 8 deletions langchain/src/retrievers/multi_query.ts
@@ -9,11 +9,15 @@ import { BaseOutputParser } from "@langchain/core/output_parsers";
import { PromptTemplate, BasePromptTemplate } from "@langchain/core/prompts";
import { CallbackManagerForRetrieverRun } from "@langchain/core/callbacks/manager";
import { LLMChain } from "../chains/llm_chain.js";
import type { BaseDocumentCompressor } from "./document_compressors/index.js";

interface LineList {
lines: string[];
}

// eslint-disable-next-line @typescript-eslint/no-explicit-any
export type MultiDocs = Document<Record<string, any>>[];

class LineListOutputParser extends BaseOutputParser<LineList> {
static lc_name() {
return "LineListOutputParser";
@@ -66,6 +70,8 @@ export interface MultiQueryRetrieverInput extends BaseRetrieverInput {
llmChain: LLMChain<LineList>;
queryCount?: number;
parserKey?: string;
documentCompressor?: BaseDocumentCompressor | undefined;
documentCompressorFilteringFn?: (docs: MultiDocs) => MultiDocs;
}

/**
@@ -96,12 +102,18 @@ export class MultiQueryRetriever extends BaseRetriever {

private parserKey = "lines";

documentCompressor: BaseDocumentCompressor | undefined;

documentCompressorFilteringFn?: MultiQueryRetrieverInput["documentCompressorFilteringFn"];

constructor(fields: MultiQueryRetrieverInput) {
super(fields);
this.retriever = fields.retriever;
this.llmChain = fields.llmChain;
this.queryCount = fields.queryCount ?? this.queryCount;
this.parserKey = fields.parserKey ?? this.parserKey;
this.documentCompressor = fields.documentCompressor;
this.documentCompressorFilteringFn = fields.documentCompressorFilteringFn;
}

static fromLLM(
@@ -145,13 +157,15 @@
runManager?: CallbackManagerForRetrieverRun
): Promise<Document[]> {
const documents: Document[] = [];
- for (const query of queries) {
- const docs = await this.retriever.getRelevantDocuments(
- query,
- runManager?.getChild()
- );
- documents.push(...docs);
- }
+ await Promise.all(
+ queries.map(async (query) => {
+ const docs = await this.retriever.getRelevantDocuments(
+ query,
+ runManager?.getChild()
+ );
+ documents.push(...docs);
+ })
+ );
return documents;
}

@@ -177,6 +191,18 @@
const queries = await this._generateQueries(question, runManager);
const documents = await this._retrieveDocuments(queries, runManager);
const uniqueDocuments = this._uniqueUnion(documents);
- return uniqueDocuments;
+
+ let outputDocs = uniqueDocuments;
+ if (this.documentCompressor && uniqueDocuments.length) {
+ outputDocs = await this.documentCompressor.compressDocuments(
+ uniqueDocuments,
+ question
+ );
+ if (this.documentCompressorFilteringFn) {
+ outputDocs = this.documentCompressorFilteringFn(outputDocs);
+ }
+ }
+
+ return outputDocs;
}
}
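
The commit adds no documentation example for MultiQueryRetriever, but the new fields above are wired the same way as in ParentDocumentRetriever. A minimal sketch under stated assumptions (the chat model, embeddings, example texts, and 0.3 threshold are illustrative, and it assumes fromLLM forwards the extra fields to the constructor, as its input type suggests):

import { ChatOpenAI, OpenAIEmbeddings } from "@langchain/openai";
import { CohereRerank } from "@langchain/cohere";
import { HNSWLib } from "@langchain/community/vectorstores/hnswlib";
import { MultiQueryRetriever } from "langchain/retrievers/multi_query";

const vectorstore = await HNSWLib.fromTexts(
  ["Buildings are made out of brick", "Buildings are made out of wood"],
  [{ id: 1 }, { id: 2 }],
  new OpenAIEmbeddings()
);

const retriever = MultiQueryRetriever.fromLLM({
  llm: new ChatOpenAI(),
  retriever: vectorstore.asRetriever(),
  // New in this commit: rerank the merged results of all generated queries...
  documentCompressor: new CohereRerank({ model: "rerank-multilingual-v2.0" }),
  // ...and optionally drop documents the reranker scores below a threshold.
  documentCompressorFilteringFn: (docs) =>
    docs.filter((doc) => (doc.metadata?.relevanceScore ?? 1) >= 0.3),
});

const docs = await retriever.getRelevantDocuments("What are buildings made of?");
console.log(docs);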
22 changes: 20 additions & 2 deletions langchain/src/retrievers/parent_document.ts
@@ -5,6 +5,7 @@ import {
type VectorStoreRetrieverInterface,
} from "@langchain/core/vectorstores";
import { Document } from "@langchain/core/documents";
import type { BaseDocumentCompressor } from "./document_compressors/index.js";
import {
TextSplitter,
TextSplitterChunkHeaderOptions,
@@ -14,6 +15,9 @@ import {
type MultiVectorRetrieverInput,
} from "./multi_vector.js";

// eslint-disable-next-line @typescript-eslint/no-explicit-any
export type SubDocs = Document<Record<string, any>>[];

/**
* Interface for the fields required to initialize a
* ParentDocumentRetriever instance.
@@ -26,6 +30,8 @@ export type ParentDocumentRetrieverFields = MultiVectorRetrieverInput & {
* the `.similaritySearch` method of the vectorstore.
*/
childDocumentRetriever?: VectorStoreRetrieverInterface<VectorStoreInterface>;
documentCompressor?: BaseDocumentCompressor | undefined;
documentCompressorFilteringFn?: (docs: SubDocs) => SubDocs;
};

/**
@@ -81,6 +87,10 @@ export class ParentDocumentRetriever extends MultiVectorRetriever {
| VectorStoreRetrieverInterface<VectorStoreInterface>
| undefined;

documentCompressor: BaseDocumentCompressor | undefined;

documentCompressorFilteringFn?: ParentDocumentRetrieverFields["documentCompressorFilteringFn"];

constructor(fields: ParentDocumentRetrieverFields) {
super(fields);
this.vectorstore = fields.vectorstore;
@@ -90,17 +100,25 @@
this.childK = fields.childK;
this.parentK = fields.parentK;
this.childDocumentRetriever = fields.childDocumentRetriever;
this.documentCompressor = fields.documentCompressor;
this.documentCompressorFilteringFn = fields.documentCompressorFilteringFn;
}

async _getRelevantDocuments(query: string): Promise<Document[]> {
- // eslint-disable-next-line @typescript-eslint/no-explicit-any
- let subDocs: Document<Record<string, any>>[] = [];
+ let subDocs: SubDocs = [];
if (this.childDocumentRetriever) {
subDocs = await this.childDocumentRetriever.getRelevantDocuments(query);
} else {
subDocs = await this.vectorstore.similaritySearch(query, this.childK);
}

if (this.documentCompressor && subDocs.length) {
subDocs = await this.documentCompressor.compressDocuments(subDocs, query);
if (this.documentCompressorFilteringFn) {
subDocs = this.documentCompressorFilteringFn(subDocs);
}
}
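// Note: the compressor runs on the child chunks before their parent documents
// are looked up, so reranking and filtering decide which parents are returned
// (see the "rerank child documents instead of parent ones" commit message above).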

// Maintain order
const parentDocIds: string[] = [];
for (const doc of subDocs) {
