forked from langchain-ai/langchainjs
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added moderation chain (langchain-ai#1061)
* added moderation chain * changed the integration test example to something less extreme * Use fetch adapater, consistent naming, consistent args * Use async caller * added example and docs * formatting fix --------- Co-authored-by: Nuno Campos <[email protected]> Co-authored-by: Priya X. Pramesi <[email protected]>
- Loading branch information
1 parent
0fe7aac
commit 9ffb159
Showing
5 changed files
with
185 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
import CodeBlock from "@theme/CodeBlock"; | ||
import OpenAIModerationExample from "@examples/chains/openai_moderation.ts"; | ||
|
||
# `OpenAIModerationChain` | ||
|
||
You can use the `OpenAIModerationChain` which takes care of evaluating the input and identifying whether it violates OpenAI's Terms of Service (TOS). If the input contains any content that breaks the TOS and throwError is set to true, an error will be thrown and caught. If throwError is set to false the chain will return "Text was found that violates OpenAI's content policy." | ||
|
||
<CodeBlock language="typescript">{OpenAIModerationExample}</CodeBlock> |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
import { OpenAIModerationChain, LLMChain } from "langchain/chains"; | ||
import { PromptTemplate } from "langchain/prompts"; | ||
import { OpenAI } from "langchain/llms/openai"; | ||
|
||
// Define an asynchronous function called run | ||
export async function run() { | ||
// A string containing potentially offensive content from the user | ||
const badString = "Bad naughty words from user"; | ||
|
||
try { | ||
// Create a new instance of the OpenAIModerationChain | ||
const moderation = new OpenAIModerationChain(); | ||
|
||
// Send the user's input to the moderation chain and wait for the result | ||
const { output: badResult } = await moderation.call({ | ||
input: badString, | ||
throwError: true, // If set to true, the call will throw an error when the moderation chain detects violating content. If set to false, violating content will return "Text was found that violates OpenAI's content policy.". | ||
}); | ||
|
||
// If the moderation chain does not detect violating content, it will return the original input and you can proceed to use the result in another chain. | ||
const model = new OpenAI({ temperature: 0 }); | ||
const template = "Hello, how are you today {person}?"; | ||
const prompt = new PromptTemplate({ template, inputVariables: ["person"] }); | ||
const chainA = new LLMChain({ llm: model, prompt }); | ||
const resA = await chainA.call({ person: badResult }); | ||
console.log({ resA }); | ||
} catch (error) { | ||
// If an error is caught, it means the input contains content that violates OpenAI TOS | ||
console.error("Naughty words detected!"); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,121 @@ | ||
import { | ||
Configuration, | ||
OpenAIApi, | ||
ConfigurationParameters, | ||
CreateModerationRequest, | ||
CreateModerationResponseResultsInner, | ||
} from "openai"; | ||
import { BaseChain, ChainInputs } from "./base.js"; | ||
import { ChainValues } from "../schema/index.js"; | ||
import fetchAdapter from "../util/axios-fetch-adapter.js"; | ||
import { AsyncCaller, AsyncCallerParams } from "../util/async_caller.js"; | ||
|
||
export interface OpenAIModerationChainInput | ||
extends ChainInputs, | ||
AsyncCallerParams { | ||
openAIApiKey?: string; | ||
openAIOrganization?: string; | ||
throwError?: boolean; | ||
configuration?: ConfigurationParameters; | ||
} | ||
|
||
export class OpenAIModerationChain | ||
extends BaseChain | ||
implements OpenAIModerationChainInput | ||
{ | ||
inputKey = "input"; | ||
|
||
outputKey = "output"; | ||
|
||
openAIApiKey?: string; | ||
|
||
openAIOrganization?: string; | ||
|
||
clientConfig: Configuration; | ||
|
||
client: OpenAIApi; | ||
|
||
throwError: boolean; | ||
|
||
caller: AsyncCaller; | ||
|
||
constructor(fields?: OpenAIModerationChainInput) { | ||
super(fields); | ||
this.throwError = fields?.throwError ?? false; | ||
this.openAIApiKey = | ||
fields?.openAIApiKey ?? | ||
// eslint-disable-next-line no-process-env | ||
(typeof process !== "undefined" ? process.env.OPENAI_API_KEY : undefined); | ||
|
||
if (!this.openAIApiKey) { | ||
throw new Error("OpenAI API key not found"); | ||
} | ||
|
||
this.openAIOrganization = fields?.openAIOrganization; | ||
|
||
this.clientConfig = new Configuration({ | ||
...fields?.configuration, | ||
apiKey: this.openAIApiKey, | ||
organization: this.openAIOrganization, | ||
baseOptions: { | ||
adapter: fetchAdapter, | ||
...fields?.configuration?.baseOptions, | ||
}, | ||
}); | ||
|
||
this.client = new OpenAIApi(this.clientConfig); | ||
|
||
this.caller = new AsyncCaller(fields ?? {}); | ||
} | ||
|
||
_moderate( | ||
text: string, | ||
results: CreateModerationResponseResultsInner | ||
): string { | ||
if (results.flagged) { | ||
const errorStr = "Text was found that violates OpenAI's content policy."; | ||
if (this.throwError) { | ||
throw new Error(errorStr); | ||
} else { | ||
return errorStr; | ||
} | ||
} | ||
return text; | ||
} | ||
|
||
async _call(values: ChainValues): Promise<ChainValues> { | ||
const text = values[this.inputKey]; | ||
const moderationRequest: CreateModerationRequest = { | ||
input: text, | ||
}; | ||
let mod; | ||
try { | ||
mod = await this.caller.call(() => | ||
this.client.createModeration(moderationRequest) | ||
); | ||
} catch (error) { | ||
// eslint-disable-next-line no-instanceof/no-instanceof | ||
if (error instanceof Error) { | ||
throw error; | ||
} else { | ||
throw new Error(error as string); | ||
} | ||
} | ||
const output = this._moderate(text, mod.data.results[0]); | ||
return { | ||
[this.outputKey]: output, | ||
}; | ||
} | ||
|
||
_chainType() { | ||
return "moderation_chain"; | ||
} | ||
|
||
get inputKeys(): string[] { | ||
return [this.inputKey]; | ||
} | ||
|
||
get outputKeys(): string[] { | ||
return [this.outputKey]; | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
import { test } from "@jest/globals"; | ||
import { OpenAIModerationChain } from "../openai_moderation.js"; | ||
|
||
test("OpenAI Moderation Test", async () => { | ||
const badString = "I hate myself and want to do harm to myself"; | ||
const goodString = | ||
"The cat (Felis catus) is a domestic species of small carnivorous mammal."; | ||
|
||
const moderation = new OpenAIModerationChain(); | ||
const { output: badResult } = await moderation.call({ | ||
input: badString, | ||
}); | ||
|
||
const { output: goodResult } = await moderation.call({ | ||
input: goodString, | ||
}); | ||
|
||
expect(badResult).toEqual( | ||
"Text was found that violates OpenAI's content policy." | ||
); | ||
expect(goodResult).toEqual( | ||
"The cat (Felis catus) is a domestic species of small carnivorous mammal." | ||
); | ||
}); |