Skip to content

Commit

Permalink
Added moderation chain (langchain-ai#1061)
Browse files Browse the repository at this point in the history
* added moderation chain

* changed the integration test example to something less extreme

* Use fetch adapater, consistent naming, consistent args

* Use async caller

* added example and docs

* formatting fix

---------

Co-authored-by: Nuno Campos <[email protected]>
Co-authored-by: Priya X. Pramesi <[email protected]>
  • Loading branch information
3 people authored May 4, 2023
1 parent 0fe7aac commit 9ffb159
Show file tree
Hide file tree
Showing 5 changed files with 185 additions and 0 deletions.
8 changes: 8 additions & 0 deletions docs/docs/modules/chains/other_chains/moderation_chain.mdx
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import CodeBlock from "@theme/CodeBlock";
import OpenAIModerationExample from "@examples/chains/openai_moderation.ts";

# `OpenAIModerationChain`

You can use the `OpenAIModerationChain` which takes care of evaluating the input and identifying whether it violates OpenAI's Terms of Service (TOS). If the input contains any content that breaks the TOS and throwError is set to true, an error will be thrown and caught. If throwError is set to false the chain will return "Text was found that violates OpenAI's content policy."

<CodeBlock language="typescript">{OpenAIModerationExample}</CodeBlock>
31 changes: 31 additions & 0 deletions examples/src/chains/openai_moderation.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import { OpenAIModerationChain, LLMChain } from "langchain/chains";
import { PromptTemplate } from "langchain/prompts";
import { OpenAI } from "langchain/llms/openai";

// Define an asynchronous function called run
export async function run() {
// A string containing potentially offensive content from the user
const badString = "Bad naughty words from user";

try {
// Create a new instance of the OpenAIModerationChain
const moderation = new OpenAIModerationChain();

// Send the user's input to the moderation chain and wait for the result
const { output: badResult } = await moderation.call({
input: badString,
throwError: true, // If set to true, the call will throw an error when the moderation chain detects violating content. If set to false, violating content will return "Text was found that violates OpenAI's content policy.".
});

// If the moderation chain does not detect violating content, it will return the original input and you can proceed to use the result in another chain.
const model = new OpenAI({ temperature: 0 });
const template = "Hello, how are you today {person}?";
const prompt = new PromptTemplate({ template, inputVariables: ["person"] });
const chainA = new LLMChain({ llm: model, prompt });
const resA = await chainA.call({ person: badResult });
console.log({ resA });
} catch (error) {
// If an error is caught, it means the input contains content that violates OpenAI TOS
console.error("Naughty words detected!");
}
}
1 change: 1 addition & 0 deletions langchain/src/chains/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,4 @@ export {
SerializedVectorDBQAChain,
SerializedRefineDocumentsChain,
} from "./serde.js";
export { OpenAIModerationChain } from "./openai_moderation.js";
121 changes: 121 additions & 0 deletions langchain/src/chains/openai_moderation.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
import {
Configuration,
OpenAIApi,
ConfigurationParameters,
CreateModerationRequest,
CreateModerationResponseResultsInner,
} from "openai";
import { BaseChain, ChainInputs } from "./base.js";
import { ChainValues } from "../schema/index.js";
import fetchAdapter from "../util/axios-fetch-adapter.js";
import { AsyncCaller, AsyncCallerParams } from "../util/async_caller.js";

export interface OpenAIModerationChainInput
extends ChainInputs,
AsyncCallerParams {
openAIApiKey?: string;
openAIOrganization?: string;
throwError?: boolean;
configuration?: ConfigurationParameters;
}

export class OpenAIModerationChain
extends BaseChain
implements OpenAIModerationChainInput
{
inputKey = "input";

outputKey = "output";

openAIApiKey?: string;

openAIOrganization?: string;

clientConfig: Configuration;

client: OpenAIApi;

throwError: boolean;

caller: AsyncCaller;

constructor(fields?: OpenAIModerationChainInput) {
super(fields);
this.throwError = fields?.throwError ?? false;
this.openAIApiKey =
fields?.openAIApiKey ??
// eslint-disable-next-line no-process-env
(typeof process !== "undefined" ? process.env.OPENAI_API_KEY : undefined);

if (!this.openAIApiKey) {
throw new Error("OpenAI API key not found");
}

this.openAIOrganization = fields?.openAIOrganization;

this.clientConfig = new Configuration({
...fields?.configuration,
apiKey: this.openAIApiKey,
organization: this.openAIOrganization,
baseOptions: {
adapter: fetchAdapter,
...fields?.configuration?.baseOptions,
},
});

this.client = new OpenAIApi(this.clientConfig);

this.caller = new AsyncCaller(fields ?? {});
}

_moderate(
text: string,
results: CreateModerationResponseResultsInner
): string {
if (results.flagged) {
const errorStr = "Text was found that violates OpenAI's content policy.";
if (this.throwError) {
throw new Error(errorStr);
} else {
return errorStr;
}
}
return text;
}

async _call(values: ChainValues): Promise<ChainValues> {
const text = values[this.inputKey];
const moderationRequest: CreateModerationRequest = {
input: text,
};
let mod;
try {
mod = await this.caller.call(() =>
this.client.createModeration(moderationRequest)
);
} catch (error) {
// eslint-disable-next-line no-instanceof/no-instanceof
if (error instanceof Error) {
throw error;
} else {
throw new Error(error as string);
}
}
const output = this._moderate(text, mod.data.results[0]);
return {
[this.outputKey]: output,
};
}

_chainType() {
return "moderation_chain";
}

get inputKeys(): string[] {
return [this.inputKey];
}

get outputKeys(): string[] {
return [this.outputKey];
}
}
24 changes: 24 additions & 0 deletions langchain/src/chains/tests/openai_moderation.int.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import { test } from "@jest/globals";
import { OpenAIModerationChain } from "../openai_moderation.js";

test("OpenAI Moderation Test", async () => {
const badString = "I hate myself and want to do harm to myself";
const goodString =
"The cat (Felis catus) is a domestic species of small carnivorous mammal.";

const moderation = new OpenAIModerationChain();
const { output: badResult } = await moderation.call({
input: badString,
});

const { output: goodResult } = await moderation.call({
input: goodString,
});

expect(badResult).toEqual(
"Text was found that violates OpenAI's content policy."
);
expect(goodResult).toEqual(
"The cat (Felis catus) is a domestic species of small carnivorous mammal."
);
});

0 comments on commit 9ffb159

Please sign in to comment.