From 9ffb159892669b3e9118741da3cfe697eb79623b Mon Sep 17 00:00:00 2001 From: "Priya X. Pramesi" Date: Thu, 4 May 2023 18:50:16 +0700 Subject: [PATCH] Added moderation chain (#1061) * added moderation chain * changed the integration test example to something less extreme * Use fetch adapater, consistent naming, consistent args * Use async caller * added example and docs * formatting fix --------- Co-authored-by: Nuno Campos Co-authored-by: Priya X. Pramesi --- .../chains/other_chains/moderation_chain.mdx | 8 ++ examples/src/chains/openai_moderation.ts | 31 +++++ langchain/src/chains/index.ts | 1 + langchain/src/chains/openai_moderation.ts | 121 ++++++++++++++++++ .../tests/openai_moderation.int.test.ts | 24 ++++ 5 files changed, 185 insertions(+) create mode 100644 docs/docs/modules/chains/other_chains/moderation_chain.mdx create mode 100644 examples/src/chains/openai_moderation.ts create mode 100644 langchain/src/chains/openai_moderation.ts create mode 100644 langchain/src/chains/tests/openai_moderation.int.test.ts diff --git a/docs/docs/modules/chains/other_chains/moderation_chain.mdx b/docs/docs/modules/chains/other_chains/moderation_chain.mdx new file mode 100644 index 000000000000..16432763e54b --- /dev/null +++ b/docs/docs/modules/chains/other_chains/moderation_chain.mdx @@ -0,0 +1,8 @@ +import CodeBlock from "@theme/CodeBlock"; +import OpenAIModerationExample from "@examples/chains/openai_moderation.ts"; + +# `OpenAIModerationChain` + +You can use the `OpenAIModerationChain` which takes care of evaluating the input and identifying whether it violates OpenAI's Terms of Service (TOS). If the input contains any content that breaks the TOS and throwError is set to true, an error will be thrown and caught. If throwError is set to false the chain will return "Text was found that violates OpenAI's content policy." + +{OpenAIModerationExample} diff --git a/examples/src/chains/openai_moderation.ts b/examples/src/chains/openai_moderation.ts new file mode 100644 index 000000000000..27bbc16dbeb0 --- /dev/null +++ b/examples/src/chains/openai_moderation.ts @@ -0,0 +1,31 @@ +import { OpenAIModerationChain, LLMChain } from "langchain/chains"; +import { PromptTemplate } from "langchain/prompts"; +import { OpenAI } from "langchain/llms/openai"; + +// Define an asynchronous function called run +export async function run() { + // A string containing potentially offensive content from the user + const badString = "Bad naughty words from user"; + + try { + // Create a new instance of the OpenAIModerationChain + const moderation = new OpenAIModerationChain(); + + // Send the user's input to the moderation chain and wait for the result + const { output: badResult } = await moderation.call({ + input: badString, + throwError: true, // If set to true, the call will throw an error when the moderation chain detects violating content. If set to false, violating content will return "Text was found that violates OpenAI's content policy.". + }); + + // If the moderation chain does not detect violating content, it will return the original input and you can proceed to use the result in another chain. + const model = new OpenAI({ temperature: 0 }); + const template = "Hello, how are you today {person}?"; + const prompt = new PromptTemplate({ template, inputVariables: ["person"] }); + const chainA = new LLMChain({ llm: model, prompt }); + const resA = await chainA.call({ person: badResult }); + console.log({ resA }); + } catch (error) { + // If an error is caught, it means the input contains content that violates OpenAI TOS + console.error("Naughty words detected!"); + } +} diff --git a/langchain/src/chains/index.ts b/langchain/src/chains/index.ts index 2f3b02b48f37..9ab3aa8ee166 100644 --- a/langchain/src/chains/index.ts +++ b/langchain/src/chains/index.ts @@ -60,3 +60,4 @@ export { SerializedVectorDBQAChain, SerializedRefineDocumentsChain, } from "./serde.js"; +export { OpenAIModerationChain } from "./openai_moderation.js"; diff --git a/langchain/src/chains/openai_moderation.ts b/langchain/src/chains/openai_moderation.ts new file mode 100644 index 000000000000..a7aa4a050c4c --- /dev/null +++ b/langchain/src/chains/openai_moderation.ts @@ -0,0 +1,121 @@ +import { + Configuration, + OpenAIApi, + ConfigurationParameters, + CreateModerationRequest, + CreateModerationResponseResultsInner, +} from "openai"; +import { BaseChain, ChainInputs } from "./base.js"; +import { ChainValues } from "../schema/index.js"; +import fetchAdapter from "../util/axios-fetch-adapter.js"; +import { AsyncCaller, AsyncCallerParams } from "../util/async_caller.js"; + +export interface OpenAIModerationChainInput + extends ChainInputs, + AsyncCallerParams { + openAIApiKey?: string; + openAIOrganization?: string; + throwError?: boolean; + configuration?: ConfigurationParameters; +} + +export class OpenAIModerationChain + extends BaseChain + implements OpenAIModerationChainInput +{ + inputKey = "input"; + + outputKey = "output"; + + openAIApiKey?: string; + + openAIOrganization?: string; + + clientConfig: Configuration; + + client: OpenAIApi; + + throwError: boolean; + + caller: AsyncCaller; + + constructor(fields?: OpenAIModerationChainInput) { + super(fields); + this.throwError = fields?.throwError ?? false; + this.openAIApiKey = + fields?.openAIApiKey ?? + // eslint-disable-next-line no-process-env + (typeof process !== "undefined" ? process.env.OPENAI_API_KEY : undefined); + + if (!this.openAIApiKey) { + throw new Error("OpenAI API key not found"); + } + + this.openAIOrganization = fields?.openAIOrganization; + + this.clientConfig = new Configuration({ + ...fields?.configuration, + apiKey: this.openAIApiKey, + organization: this.openAIOrganization, + baseOptions: { + adapter: fetchAdapter, + ...fields?.configuration?.baseOptions, + }, + }); + + this.client = new OpenAIApi(this.clientConfig); + + this.caller = new AsyncCaller(fields ?? {}); + } + + _moderate( + text: string, + results: CreateModerationResponseResultsInner + ): string { + if (results.flagged) { + const errorStr = "Text was found that violates OpenAI's content policy."; + if (this.throwError) { + throw new Error(errorStr); + } else { + return errorStr; + } + } + return text; + } + + async _call(values: ChainValues): Promise { + const text = values[this.inputKey]; + const moderationRequest: CreateModerationRequest = { + input: text, + }; + let mod; + try { + mod = await this.caller.call(() => + this.client.createModeration(moderationRequest) + ); + } catch (error) { + // eslint-disable-next-line no-instanceof/no-instanceof + if (error instanceof Error) { + throw error; + } else { + throw new Error(error as string); + } + } + const output = this._moderate(text, mod.data.results[0]); + return { + [this.outputKey]: output, + }; + } + + _chainType() { + return "moderation_chain"; + } + + get inputKeys(): string[] { + return [this.inputKey]; + } + + get outputKeys(): string[] { + return [this.outputKey]; + } +} diff --git a/langchain/src/chains/tests/openai_moderation.int.test.ts b/langchain/src/chains/tests/openai_moderation.int.test.ts new file mode 100644 index 000000000000..a6242c6e9cd2 --- /dev/null +++ b/langchain/src/chains/tests/openai_moderation.int.test.ts @@ -0,0 +1,24 @@ +import { test } from "@jest/globals"; +import { OpenAIModerationChain } from "../openai_moderation.js"; + +test("OpenAI Moderation Test", async () => { + const badString = "I hate myself and want to do harm to myself"; + const goodString = + "The cat (Felis catus) is a domestic species of small carnivorous mammal."; + + const moderation = new OpenAIModerationChain(); + const { output: badResult } = await moderation.call({ + input: badString, + }); + + const { output: goodResult } = await moderation.call({ + input: goodString, + }); + + expect(badResult).toEqual( + "Text was found that violates OpenAI's content policy." + ); + expect(goodResult).toEqual( + "The cat (Felis catus) is a domestic species of small carnivorous mammal." + ); +});