Skip to content

Commit

Permalink
fix(memory): throw error if no input/output key is specified and ther…
Browse files Browse the repository at this point in the history
…e are multiple keys (langchain-ai#1778)

* fix(memory): throw error if no input/output key is specified and there are multiple keys

* Fix formatting

---------

Co-authored-by: jacoblee93 <[email protected]>
  • Loading branch information
joaopcm and jacoblee93 authored Jul 7, 2023
1 parent 8db82da commit f04352e
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 11 deletions.
44 changes: 36 additions & 8 deletions langchain/src/memory/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,22 +17,50 @@ export abstract class BaseMemory {
): Promise<void>;
}

const getValue = (values: InputValues | OutputValues, key?: string) => {
if (key !== undefined) {
return values[key];
}
const keys = Object.keys(values);
if (keys.length === 1) {
return values[keys[0]];
}
};

/**
* This function is used by memory classes to select the input value
* to use for the memory. If there is only one input value, it is used.
* If there are multiple input values, the inputKey must be specified.
*/
export const getInputValue = (inputValues: InputValues, inputKey?: string) => {
if (inputKey !== undefined) {
return inputValues[inputKey];
const value = getValue(inputValues, inputKey);
if (!value) {
const keys = Object.keys(inputValues);
throw new Error(
`input values have ${keys.length} keys, you must specify an input key or pass only 1 key as input`
);
}
const keys = Object.keys(inputValues);
if (keys.length === 1) {
return inputValues[keys[0]];
return value;
};

/**
* This function is used by memory classes to select the output value
* to use for the memory. If there is only one output value, it is used.
* If there are multiple output values, the outputKey must be specified.
* If no outputKey is specified, an error is thrown.
*/
export const getOutputValue = (
outputValues: OutputValues,
outputKey?: string
) => {
const value = getValue(outputValues, outputKey);
if (!value) {
const keys = Object.keys(outputValues);
throw new Error(
`output values have ${keys.length} keys, you must specify an output key or pass only 1 key as output`
);
}
throw new Error(
`input values have ${keys.length} keys, you must specify an input key or pass only 1 key as input`
);
return value;
};

/**
Expand Down
3 changes: 2 additions & 1 deletion langchain/src/memory/chat_memory.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import {
InputValues,
OutputValues,
getInputValue,
getOutputValue,
} from "./base.js";
import { ChatMessageHistory } from "../stores/message/in_memory.js";

Expand Down Expand Up @@ -40,7 +41,7 @@ export abstract class BaseChatMemory extends BaseMemory {
getInputValue(inputValues, this.inputKey)
);
await this.chatHistory.addAIChatMessage(
getInputValue(outputValues, this.outputKey)
getOutputValue(outputValues, this.outputKey)
);
}

Expand Down
3 changes: 2 additions & 1 deletion langchain/src/memory/motorhead_memory.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import {
MemoryVariables,
getBufferString,
getInputValue,
getOutputValue,
} from "./base.js";
import { AsyncCaller, AsyncCallerParams } from "../util/async_caller.js";

Expand Down Expand Up @@ -143,7 +144,7 @@ export class MotorheadMemory extends BaseChatMemory {
outputValues: OutputValues
): Promise<void> {
const input = getInputValue(inputValues, this.inputKey);
const output = getInputValue(outputValues, this.outputKey);
const output = getOutputValue(outputValues, this.outputKey);
await Promise.all([
this.caller.call(fetch, `${this.url}/sessions/${this.sessionId}/memory`, {
signal: this.timeout ? AbortSignal.timeout(this.timeout) : undefined,
Expand Down
3 changes: 2 additions & 1 deletion langchain/src/memory/zep.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import { Memory, Message, NotFoundError, ZepClient } from "@getzep/zep-js";
import {
getBufferString,
getInputValue,
getOutputValue,
InputValues,
MemoryVariables,
OutputValues,
Expand Down Expand Up @@ -122,7 +123,7 @@ export class ZepMemory extends BaseChatMemory implements ZepMemoryInput {
outputValues: OutputValues
): Promise<void> {
const input = getInputValue(inputValues, this.inputKey);
const output = getInputValue(outputValues, this.outputKey);
const output = getOutputValue(outputValues, this.outputKey);

// Create new Memory and Message instances
const memory = new Memory({
Expand Down

0 comments on commit f04352e

Please sign in to comment.