Skip to content

Commit

Permalink
Added type to LLMChain to be used on predict (langchain-ai#1135)
Browse files Browse the repository at this point in the history
* Added a generic type to LLMChain so predict can return either a string or an object, now that we have a structured output parser

* Make the output type inferable from the output type of the output parser

* Add default

---------

Co-authored-by: Nuno Campos <[email protected]>
  • Loading branch information
ppramesi and nfcampos authored May 5, 2023
1 parent 4e22af2 commit 478fcd9
Showing 1 changed file with 11 additions and 7 deletions.
18 changes: 11 additions & 7 deletions langchain/src/chains/llm_chain.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,14 @@ import { SerializedLLMChain } from "./serde.js";
import { CallbackManager } from "../callbacks/index.js";
import { CallbackManagerForChainRun } from "../callbacks/manager.js";

export interface LLMChainInput extends ChainInputs {
export interface LLMChainInput<T extends string | object = string>
extends ChainInputs {
/** Prompt object to use */
prompt: BasePromptTemplate;
/** LLM Wrapper to use */
llm: BaseLanguageModel;
/** OutputParser to use */
outputParser?: BaseOutputParser;
outputParser?: BaseOutputParser<T>;
/** Key to use for output, defaults to `text` */
outputKey?: string;
}
Expand All @@ -31,14 +32,17 @@ export interface LLMChainInput extends ChainInputs {
* const llm = new LLMChain({ llm: new OpenAI(), prompt });
* ```
*/
export class LLMChain extends BaseChain implements LLMChainInput {
export class LLMChain<T extends string | object = string>
extends BaseChain
implements LLMChainInput<T>
{
prompt: BasePromptTemplate;

llm: BaseLanguageModel;

outputKey = "text";

outputParser?: BaseOutputParser;
outputParser?: BaseOutputParser<T>;

get inputKeys() {
return this.prompt.inputVariables;
Expand All @@ -48,7 +52,7 @@ export class LLMChain extends BaseChain implements LLMChainInput {
return [this.outputKey];
}

constructor(fields: LLMChainInput) {
constructor(fields: LLMChainInput<T>) {
super(fields);
this.prompt = fields.prompt;
this.llm = fields.llm;
Expand All @@ -58,7 +62,7 @@ export class LLMChain extends BaseChain implements LLMChainInput {
if (this.outputParser) {
throw new Error("Cannot set both outputParser and prompt.outputParser");
}
this.outputParser = this.prompt.outputParser;
this.outputParser = this.prompt.outputParser as BaseOutputParser<T>;
}
}

Expand Down Expand Up @@ -121,7 +125,7 @@ export class LLMChain extends BaseChain implements LLMChainInput {
async predict(
values: ChainValues,
callbackManager?: CallbackManager
): Promise<string> {
): Promise<T> {
const output = await this.call(values, callbackManager);
return output[this.outputKey];
}
Expand Down

0 comments on commit 478fcd9

Please sign in to comment.