Commit
feat: added returnIntermediateSteps flag to MapReduceChain (langchain-ai#1080)

* feat: added returnIntermediateSteps flag to MapReduceChain

* docs: added extra documentation around the intermediate steps

* fix: linting issue

* docs: set map_reduce in backticks

* feat: added option to loadQAChain as well

* Populate params via destructuring, adds unit test

* Lint

---------

Co-authored-by: Jacob Lee <[email protected]>
Co-authored-by: Nuno Campos <[email protected]>
3 people authored May 4, 2023
1 parent 0eccc85 commit 0fe7aac
Showing 8 changed files with 139 additions and 24 deletions.
7 changes: 7 additions & 0 deletions docs/docs/modules/chains/other_chains/summarization.mdx
@@ -1,8 +1,15 @@
import CodeBlock from "@theme/CodeBlock";
import SummarizeExample from "@examples/chains/summarization_map_reduce.ts";
import SummarizeExampleIntermediateSteps from "@examples/chains/summarization_map_reduce_intermediate_steps.ts";

# Summarization

A summarization chain can be used to summarize multiple documents. One approach is to split the input into smaller chunks and operate over them with a `MapReduceDocumentsChain`. Alternatively, you can use a `StuffDocumentsChain` or a `RefineDocumentsChain` for summarization. See more about the differences between them [here](../index_related_chains/document_qa).

<CodeBlock language="typescript">{SummarizeExample}</CodeBlock>

## Intermediate Steps

We can also return the intermediate steps for `map_reduce` chains, should we want to inspect them. This is done with the `returnIntermediateSteps` parameter.

<CodeBlock language="typescript">{SummarizeExampleIntermediateSteps}</CodeBlock>
2 changes: 1 addition & 1 deletion examples/src/chains/summarization_map_reduce.ts
@@ -11,7 +11,7 @@ export const run = async () => {
const docs = await textSplitter.createDocuments([text]);

// This convenience function creates a document chain prompted to summarize a set of documents.
const chain = loadSummarizationChain(model);
const chain = loadSummarizationChain(model, { type: "map_reduce" });
const res = await chain.call({
input_documents: docs,
});
36 changes: 36 additions & 0 deletions examples/src/chains/summarization_map_reduce_intermediate_steps.ts
@@ -0,0 +1,36 @@
import { OpenAI } from "langchain/llms/openai";
import { loadSummarizationChain } from "langchain/chains";
import { RecursiveCharacterTextSplitter } from "langchain/text_splitter";
import * as fs from "fs";

export const run = async () => {
// In this example, we use a `MapReduceDocumentsChain` specifically prompted to summarize a set of documents.
const text = fs.readFileSync("state_of_the_union.txt", "utf8");
const model = new OpenAI({ temperature: 0 });
const textSplitter = new RecursiveCharacterTextSplitter({ chunkSize: 1000 });
const docs = await textSplitter.createDocuments([text]);

// This convenience function creates a document chain prompted to summarize a set of documents.
const chain = loadSummarizationChain(model, {
type: "map_reduce",
returnIntermediateSteps: true,
});
const res = await chain.call({
input_documents: docs,
});
console.log({ res });
/*
{
res: {
intermediateSteps: [
"In response to Russia's aggression in Ukraine, the United States has united with other freedom-loving nations to impose economic sanctions and hold Putin accountable. The U.S. Department of Justice is also assembling a task force to go after the crimes of Russian oligarchs and seize their ill-gotten gains.",
"The United States and its European allies are taking action to punish Russia for its invasion of Ukraine, including seizing assets, closing off airspace, and providing economic and military assistance to Ukraine. The US is also mobilizing forces to protect NATO countries and has released 30 million barrels of oil from its Strategic Petroleum Reserve to help blunt gas prices. The world is uniting in support of Ukraine and democracy, and the US stands with its Ukrainian-American citizens.",
" President Biden and Vice President Harris ran for office with a new economic vision for America, and have since passed the American Rescue Plan and the Bipartisan Infrastructure Law to help struggling families and rebuild America's infrastructure. This includes creating jobs, modernizing roads, airports, ports, and waterways, replacing lead pipes, providing affordable high-speed internet, and investing in American products to support American jobs.",
],
text: "President Biden is taking action to protect Americans from the COVID-19 pandemic and Russian aggression, providing economic relief, investing in infrastructure, creating jobs, and fighting inflation.
He is also proposing measures to reduce the cost of prescription drugs, protect voting rights, and reform the immigration system. The speaker is advocating for increased economic security, police reform, and the Equality Act, as well as providing support for veterans and military families.
The US is making progress in the fight against COVID-19, and the speaker is encouraging Americans to come together and work towards a brighter future.",
},
}
*/
};
30 changes: 30 additions & 0 deletions langchain/src/chains/combine_docs_chain.ts
@@ -96,10 +96,16 @@ export class StuffDocumentsChain
}

export interface MapReduceDocumentsChainInput extends StuffDocumentsChainInput {
/** The maximum number of tokens the combined documents may contain before another map step is required */
maxTokens?: number;
/** The maximum number of iterations to run through the map */
maxIterations?: number;
/** Ensures that the map step is taken regardless of max tokens */
ensureMapStep?: boolean;
/** Chain to use to combine results of applying llm_chain to documents. */
combineDocumentChain: BaseChain;
/** Return the results of the map steps in the output. */
returnIntermediateSteps?: boolean;
}

/**
@@ -117,6 +123,8 @@ export class MapReduceDocumentsChain

documentVariableName = "context";

returnIntermediateSteps = false;

get inputKeys() {
return [this.inputKey, ...this.combineDocumentChain.inputKeys];
}
@@ -143,6 +151,7 @@
this.inputKey = fields.inputKey ?? this.inputKey;
this.maxTokens = fields.maxTokens ?? this.maxTokens;
this.maxIterations = fields.maxIterations ?? this.maxIterations;
this.returnIntermediateSteps = fields.returnIntermediateSteps ?? false;
}

/** @ignore */
@@ -156,12 +165,16 @@
const { [this.inputKey]: docs, ...rest } = values;

let currentDocs = docs as Document[];
let intermediateSteps: string[] = [];

// For each iteration, we'll use the `llmChain` to get a new result
for (let i = 0; i < this.maxIterations; i += 1) {
const inputs = currentDocs.map((d) => ({
[this.documentVariableName]: d.pageContent,
...rest,
}));

// Calculate the total tokens required in the input
const promises = inputs.map(async (i) => {
const prompt = await this.llmChain.prompt.format(i);
return this.llmChain.llm.getNumTokens(prompt);
@@ -173,6 +186,8 @@

const canSkipMapStep = i !== 0 || !this.ensureMapStep;
const withinTokenLimit = length < this.maxTokens;
// If we can skip the map step, and we're within the token limit, we don't
// need to run the map step, so just break out of the loop.
if (canSkipMapStep && withinTokenLimit) {
break;
}
@@ -183,15 +198,30 @@
);
const { outputKey } = this.llmChain;

// If the flag is set, append this round's map results to the intermediate steps
if (this.returnIntermediateSteps) {
intermediateSteps = intermediateSteps.concat(
results.map((r: ChainValues) => r[outputKey])
);
}

currentDocs = results.map((r: ChainValues) => ({
pageContent: r[outputKey],
}));
}

// Now, with the final result of all the inputs from the `llmChain`, we can
// run the `combineDocumentChain` over them.
const newInputs = { input_documents: currentDocs, ...rest };
const result = await this.combineDocumentChain.call(
newInputs,
runManager?.getChild()
);

// Return the intermediate steps results if the flag is set
if (this.returnIntermediateSteps) {
return { ...result, intermediateSteps };
}
return result;
}

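The diff above adds the core behavior to `combineDocs`: on each map round, the per-document outputs are collected into `intermediateSteps` when the flag is set, then merged into the final result. As a sanity check, here is a dependency-free TypeScript sketch of that control flow — `mapStep`, `reduceStep`, `maxChars`, and `mapReduce` are hypothetical stand-ins for the real LLM chains and token counting, not the langchain API:

```typescript
// Dependency-free sketch of the map-reduce loop with the new flag.
// Character counts stand in for token counts.
type Doc = { pageContent: string };

// Stand-in for the map-side llmChain: shorten each document.
const mapStep = (doc: Doc): string => doc.pageContent.slice(0, 5);

// Stand-in for combineDocumentChain: join the final documents.
const reduceStep = (docs: Doc[]): string =>
  docs.map((d) => d.pageContent).join(" | ");

function mapReduce(
  docs: Doc[],
  opts: {
    maxChars?: number;
    maxIterations?: number;
    returnIntermediateSteps?: boolean;
  } = {}
): { text: string; intermediateSteps?: string[] } {
  const { maxChars = 20, maxIterations = 10, returnIntermediateSteps = false } = opts;
  let currentDocs = docs;
  let intermediateSteps: string[] = [];

  for (let i = 0; i < maxIterations; i += 1) {
    const length = currentDocs.reduce((n, d) => n + d.pageContent.length, 0);
    // Within the limit: no further map steps are needed.
    if (length < maxChars) break;
    const results = currentDocs.map(mapStep);
    // If the flag is set, record this round's map outputs.
    if (returnIntermediateSteps) {
      intermediateSteps = intermediateSteps.concat(results);
    }
    currentDocs = results.map((pageContent) => ({ pageContent }));
  }

  const text = reduceStep(currentDocs);
  return returnIntermediateSteps ? { text, intermediateSteps } : { text };
}

const res = mapReduce(
  [{ pageContent: "a".repeat(30) }, { pageContent: "b".repeat(30) }],
  { returnIntermediateSteps: true }
);
console.log(res);
// { text: "aaaaa | bbbbb", intermediateSteps: ["aaaaa", "bbbbb"] }
```

When the flag is off, the `intermediateSteps` key is omitted entirely from the result, matching the chain's conditional return above.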
4 changes: 4 additions & 0 deletions langchain/src/chains/index.ts
@@ -26,9 +26,13 @@ export {
export { VectorDBQAChain, VectorDBQAChainInput } from "./vector_db_qa.js";
export {
loadQAChain,
QAChainParams,
loadQAStuffChain,
StuffQAChainParams,
loadQAMapReduceChain,
MapReduceQAChainParams,
loadQARefineChain,
RefineQAChainParams,
} from "./question_answering/load.js";
export {
loadSummarizationChain,
51 changes: 30 additions & 21 deletions langchain/src/chains/question_answering/load.ts
@@ -4,6 +4,7 @@ import {
StuffDocumentsChain,
MapReduceDocumentsChain,
RefineDocumentsChain,
MapReduceDocumentsChainInput,
} from "../combine_docs_chain.js";
import { QA_PROMPT_SELECTOR, DEFAULT_QA_PROMPT } from "./stuff_prompts.js";
import {
@@ -18,33 +19,34 @@ import {
REFINE_PROMPT_SELECTOR,
} from "./refine_prompts.js";

interface qaChainParams {
prompt?: BasePromptTemplate;
combineMapPrompt?: BasePromptTemplate;
combinePrompt?: BasePromptTemplate;
questionPrompt?: BasePromptTemplate;
refinePrompt?: BasePromptTemplate;
type?: string;
verbose?: boolean;
}
export type QAChainParams =
| ({
type?: "stuff";
} & StuffQAChainParams)
| ({
type?: "map_reduce";
} & MapReduceQAChainParams)
| ({
type?: "refine";
} & RefineQAChainParams);

export const loadQAChain = (
llm: BaseLanguageModel,
params: qaChainParams = {}
params: QAChainParams = { type: "stuff" }
) => {
const {
prompt = DEFAULT_QA_PROMPT,
combineMapPrompt = DEFAULT_COMBINE_QA_PROMPT,
combinePrompt = COMBINE_PROMPT,
type = "stuff",
verbose,
} = params;
const { type, verbose } = params;
if (type === "stuff") {
const { prompt = DEFAULT_QA_PROMPT } = params;
const llmChain = new LLMChain({ prompt, llm, verbose });
const chain = new StuffDocumentsChain({ llmChain });
const chain = new StuffDocumentsChain({ llmChain, verbose });
return chain;
}
if (type === "map_reduce") {
const {
combineMapPrompt = DEFAULT_COMBINE_QA_PROMPT,
combinePrompt = COMBINE_PROMPT,
returnIntermediateSteps,
} = params;
const llmChain = new LLMChain({ prompt: combineMapPrompt, llm, verbose });
const combineLLMChain = new LLMChain({
prompt: combinePrompt,
@@ -54,10 +56,13 @@ export const loadQAChain = (
const combineDocumentChain = new StuffDocumentsChain({
llmChain: combineLLMChain,
documentVariableName: "summaries",
verbose,
});
const chain = new MapReduceDocumentsChain({
llmChain,
combineDocumentChain,
returnIntermediateSteps,
verbose,
});
return chain;
}
@@ -72,13 +77,14 @@ export const loadQAChain = (
const chain = new RefineDocumentsChain({
llmChain,
refineLLMChain,
verbose,
});
return chain;
}
throw new Error(`Invalid _type: ${type}`);
};

interface StuffQAChainParams {
export interface StuffQAChainParams {
prompt?: BasePromptTemplate;
verbose?: boolean;
}
@@ -93,7 +99,8 @@ export const loadQAStuffChain = (
return chain;
};

interface MapReduceQAChainParams {
export interface MapReduceQAChainParams {
returnIntermediateSteps?: MapReduceDocumentsChainInput["returnIntermediateSteps"];
combineMapPrompt?: BasePromptTemplate;
combinePrompt?: BasePromptTemplate;
verbose?: boolean;
@@ -107,6 +114,7 @@ export const loadQAMapReduceChain = (
combineMapPrompt = COMBINE_QA_PROMPT_SELECTOR.getPrompt(llm),
combinePrompt = COMBINE_PROMPT_SELECTOR.getPrompt(llm),
verbose,
returnIntermediateSteps,
} = params;
const llmChain = new LLMChain({ prompt: combineMapPrompt, llm, verbose });
const combineLLMChain = new LLMChain({ prompt: combinePrompt, llm, verbose });
@@ -117,11 +125,12 @@
const chain = new MapReduceDocumentsChain({
llmChain,
combineDocumentChain,
returnIntermediateSteps,
});
return chain;
};

interface RefineQAChainParams {
export interface RefineQAChainParams {
questionPrompt?: BasePromptTemplate;
refinePrompt?: BasePromptTemplate;
verbose?: boolean;
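The `QAChainParams` change above replaces one loose options bag with a discriminated union, so each `loadQAChain` branch destructures only the fields valid for its `type`. A minimal sketch of the pattern — all names here (`ChainParams`, `loadChain`, the prompt defaults) are illustrative, not the real langchain types, and the discriminant is made required for clean narrowing:

```typescript
// Discriminated union: TypeScript narrows `params` inside each branch,
// so map-reduce-only options are visible only when type is "map_reduce".
type StuffParams = { type: "stuff"; prompt?: string };
type MapReduceParams = {
  type: "map_reduce";
  combinePrompt?: string;
  returnIntermediateSteps?: boolean;
};
type ChainParams = StuffParams | MapReduceParams;

function loadChain(params: ChainParams = { type: "stuff" }): string {
  if (params.type === "map_reduce") {
    // Narrowed to MapReduceParams: destructure map-reduce options here.
    const { combinePrompt = "combine", returnIntermediateSteps = false } = params;
    return `map_reduce(${combinePrompt}, steps=${returnIntermediateSteps})`;
  }
  // Narrowed to StuffParams.
  const { prompt = "default" } = params;
  return `stuff(${prompt})`;
}
```

This is why the diff moves the destructuring from one block at the top of `loadQAChain` into each `if` branch: the compiler only allows access to branch-specific fields after the `type` check.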
7 changes: 5 additions & 2 deletions langchain/src/chains/summarization/load.ts
@@ -5,6 +5,7 @@ import {
StuffDocumentsChain,
MapReduceDocumentsChain,
RefineDocumentsChain,
MapReduceDocumentsChainInput,
} from "../combine_docs_chain.js";
import { DEFAULT_PROMPT } from "./stuff_prompts.js";
import { REFINE_PROMPT } from "./refine_prompts.js";
@@ -14,11 +15,11 @@ export type SummarizationChainParams =
type?: "stuff";
prompt?: BasePromptTemplate;
}
| {
| ({
type?: "map_reduce";
combineMapPrompt?: BasePromptTemplate;
combinePrompt?: BasePromptTemplate;
}
} & Pick<MapReduceDocumentsChainInput, "returnIntermediateSteps">)
| {
type?: "refine";
refinePrompt?: BasePromptTemplate;
@@ -42,6 +43,7 @@ export const loadSummarizationChain = (
const {
combineMapPrompt = DEFAULT_PROMPT,
combinePrompt = DEFAULT_PROMPT,
returnIntermediateSteps,
} = params;
const llmChain = new LLMChain({ prompt: combineMapPrompt, llm });
const combineLLMChain = new LLMChain({ prompt: combinePrompt, llm });
@@ -53,6 +55,7 @@
llmChain,
combineDocumentChain,
documentVariableName: "text",
returnIntermediateSteps,
});
return chain;
}
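By typing the summarization flag as `Pick<MapReduceDocumentsChainInput, "returnIntermediateSteps">` rather than redeclaring it, the summarization params share one declaration with the map-reduce chain and cannot drift apart. A small illustration of the `Pick` utility type, using hypothetical interface names:

```typescript
// `Pick` builds a new type from selected properties of an existing one,
// keeping their exact (optional) types. Interface names are illustrative.
interface MapReduceOptions {
  maxTokens?: number;
  maxIterations?: number;
  returnIntermediateSteps?: boolean;
}

// Only the flag is shared with the summarization options.
type SharedFlag = Pick<MapReduceOptions, "returnIntermediateSteps">;

type SummarizationOptions = {
  combinePrompt?: string;
} & SharedFlag;

const opts: SummarizationOptions = {
  combinePrompt: "summarize",
  returnIntermediateSteps: true,
};
console.log(opts.returnIntermediateSteps); // true
```

If the flag's type ever changes on the source interface, every `Pick` of it updates automatically at compile time.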
26 changes: 26 additions & 0 deletions langchain/src/chains/tests/combine_docs_chain.test.ts
@@ -78,3 +78,29 @@ test("Test MapReduceDocumentsChain with content above maxTokens", async () => {
expect(model.nrMapCalls).toBe(2); // above maxTokens
expect(model.nrReduceCalls).toBe(1);
});

test("Test MapReduceDocumentsChain with content above maxTokens and intermediate steps", async () => {
const model = new FakeLLM({});
const chain = loadQAMapReduceChain(model, {
returnIntermediateSteps: true,
});
const aString = "a".repeat(10000);
const bString = "b".repeat(10000);
const docs = [
new Document({ pageContent: aString }),
new Document({ pageContent: bString }),
];

const res = await chain.call({
input_documents: docs,
question: "Is the letter c present in the document",
});
console.log({ res });

expect(res).toEqual({
text: "a final answer",
intermediateSteps: ["a portion of context", "a portion of context"],
});
expect(model.nrMapCalls).toBe(2); // above maxTokens
expect(model.nrReduceCalls).toBe(1);
});
