Skip to content

Commit

Permalink
chore: DRY-up the loadQAChain code so it's just using the same func…
Browse files Browse the repository at this point in the history
…tions (langchain-ai#1122)

* chore: DRY-up the load code so it's just using the same functions

* Lint

---------

Co-authored-by: Nuno Campos <[email protected]>
  • Loading branch information
justindra and nfcampos authored May 5, 2023
1 parent bd27d17 commit 085fcf4
Showing 1 changed file with 18 additions and 54 deletions.
72 changes: 18 additions & 54 deletions langchain/src/chains/question_answering/load.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,8 @@ import {
RefineDocumentsChain,
MapReduceDocumentsChainInput,
} from "../combine_docs_chain.js";
import { QA_PROMPT_SELECTOR, DEFAULT_QA_PROMPT } from "./stuff_prompts.js";
import { QA_PROMPT_SELECTOR } from "./stuff_prompts.js";
import {
COMBINE_PROMPT,
DEFAULT_COMBINE_QA_PROMPT,
COMBINE_PROMPT_SELECTOR,
COMBINE_QA_PROMPT_SELECTOR,
} from "./map_reduce_prompts.js";
Expand All @@ -34,52 +32,15 @@ export const loadQAChain = (
llm: BaseLanguageModel,
params: QAChainParams = { type: "stuff" }
) => {
const { type, verbose } = params;
const { type } = params;
if (type === "stuff") {
const { prompt = DEFAULT_QA_PROMPT } = params;
const llmChain = new LLMChain({ prompt, llm, verbose });
const chain = new StuffDocumentsChain({ llmChain, verbose });
return chain;
return loadQAStuffChain(llm, params);
}
if (type === "map_reduce") {
const {
combineMapPrompt = DEFAULT_COMBINE_QA_PROMPT,
combinePrompt = COMBINE_PROMPT,
returnIntermediateSteps,
} = params;
const llmChain = new LLMChain({ prompt: combineMapPrompt, llm, verbose });
const combineLLMChain = new LLMChain({
prompt: combinePrompt,
llm,
verbose,
});
const combineDocumentChain = new StuffDocumentsChain({
llmChain: combineLLMChain,
documentVariableName: "summaries",
verbose,
});
const chain = new MapReduceDocumentsChain({
llmChain,
combineDocumentChain,
returnIntermediateSteps,
verbose,
});
return chain;
return loadQAMapReduceChain(llm, params);
}
if (type === "refine") {
const {
questionPrompt = QUESTION_PROMPT_SELECTOR.getPrompt(llm),
refinePrompt = REFINE_PROMPT_SELECTOR.getPrompt(llm),
} = params;
const llmChain = new LLMChain({ prompt: questionPrompt, llm, verbose });
const refineLLMChain = new LLMChain({ prompt: refinePrompt, llm, verbose });

const chain = new RefineDocumentsChain({
llmChain,
refineLLMChain,
verbose,
});
return chain;
return loadQARefineChain(llm, params);
}
throw new Error(`Invalid _type: ${type}`);
};
Expand All @@ -89,15 +50,15 @@ export interface StuffQAChainParams {
verbose?: boolean;
}

export const loadQAStuffChain = (
export function loadQAStuffChain(
llm: BaseLanguageModel,
params: StuffQAChainParams = {}
) => {
) {
const { prompt = QA_PROMPT_SELECTOR.getPrompt(llm), verbose } = params;
const llmChain = new LLMChain({ prompt, llm, verbose });
const chain = new StuffDocumentsChain({ llmChain });
const chain = new StuffDocumentsChain({ llmChain, verbose });
return chain;
};
}

export interface MapReduceQAChainParams {
returnIntermediateSteps?: MapReduceDocumentsChainInput["returnIntermediateSteps"];
Expand All @@ -106,10 +67,10 @@ export interface MapReduceQAChainParams {
verbose?: boolean;
}

export const loadQAMapReduceChain = (
export function loadQAMapReduceChain(
llm: BaseLanguageModel,
params: MapReduceQAChainParams = {}
) => {
) {
const {
combineMapPrompt = COMBINE_QA_PROMPT_SELECTOR.getPrompt(llm),
combinePrompt = COMBINE_PROMPT_SELECTOR.getPrompt(llm),
Expand All @@ -121,25 +82,27 @@ export const loadQAMapReduceChain = (
const combineDocumentChain = new StuffDocumentsChain({
llmChain: combineLLMChain,
documentVariableName: "summaries",
verbose,
});
const chain = new MapReduceDocumentsChain({
llmChain,
combineDocumentChain,
returnIntermediateSteps,
verbose,
});
return chain;
};
}

export interface RefineQAChainParams {
questionPrompt?: BasePromptTemplate;
refinePrompt?: BasePromptTemplate;
verbose?: boolean;
}

export const loadQARefineChain = (
export function loadQARefineChain(
llm: BaseLanguageModel,
params: RefineQAChainParams = {}
) => {
) {
const {
questionPrompt = QUESTION_PROMPT_SELECTOR.getPrompt(llm),
refinePrompt = REFINE_PROMPT_SELECTOR.getPrompt(llm),
Expand All @@ -151,6 +114,7 @@ export const loadQARefineChain = (
const chain = new RefineDocumentsChain({
llmChain,
refineLLMChain,
verbose,
});
return chain;
};
}

0 comments on commit 085fcf4

Please sign in to comment.