Skip to content

Commit

Permalink
allow choose model for chatgpt webapp
Browse files Browse the repository at this point in the history
  • Loading branch information
yaozhiwang committed Jun 17, 2023
1 parent da1e0e7 commit 69a887a
Show file tree
Hide file tree
Showing 10 changed files with 159 additions and 81 deletions.
6 changes: 3 additions & 3 deletions src/background/summarize.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@ export async function summarize(port: chrome.runtime.Port, text: string) {
let provider: Provider
if (providerType === ProviderType.ChatGPTWebApp) {
const config = await storage.get<ChatGPTWebAppProviderConfig>(configKey)
provider = new ChatGPTWebAppProvider(prompt, config)
provider = new ChatGPTWebAppProvider(config)
} else if (providerType === ProviderType.OpenaiChatApi) {
const config = await storage.get<OpenAIProviderConfig>(configKey)
provider = new OpenAIChatProvider(prompt, config)
provider = new OpenAIChatProvider(config)
} else {
throw new Error(`Unknown provider ${providerType}`)
}
Expand All @@ -37,7 +37,7 @@ export async function summarize(port: chrome.runtime.Port, text: string) {
disconnected = true
cleanup?.()
})
const ret = await provider.summarize(text, {
const ret = await provider.summarize(prompt, text, {
signal: controller.signal,
onLoading(msg) {
!disconnected && port.postMessage({ loading: msg })
Expand Down
34 changes: 33 additions & 1 deletion src/config/provider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,17 @@ export const ProviderTypeName = {
[ProviderType.OpenaiChatApi]: "OpenAI API"
}

export enum ChatGPTWebModelNames {
"text-davinci-002-render-sha" = "GPT-3.5",
"text-davinci-002-render-sha-mobile" = "GPT-3.5 (Mobile)",
"gpt-4" = "GPT-4",
"gpt-4-mobile" = "GPT-4 (Mobile)",
"gpt-4-browsing" = "GPT-4 Browsing"
}
export const defaultChatGPTWebModel = "text-davinci-002-render-sha"

export interface ChatGPTWebAppProviderConfig {
model: string
cleanup: boolean
}

Expand All @@ -28,7 +38,10 @@ export const defaultOpenaiAPIHost = "https://api.openai.com"
export const providerTypeConfigKey = "provider"
const defaultProviderConfig = {
[providerTypeConfigKey]: ProviderType.ChatGPTWebApp,
[getProviderConfigKey(ProviderType.ChatGPTWebApp)]: { cleanup: true },
[getProviderConfigKey(ProviderType.ChatGPTWebApp)]: {
cleanup: true,
model: defaultChatGPTWebModel
},
[getProviderConfigKey(ProviderType.OpenaiChatApi)]: {
apiKey: "",
apiHost: defaultOpenaiAPIHost,
Expand All @@ -52,13 +65,32 @@ export async function saveDefaultProviderConfigs() {
export async function migrateDefaultProviderConfigs(previousVersion: string) {
const storage = new Storage()

await migrateDefaultOpenaiApiConfig(storage, previousVersion)
await migrateDefaultChatGPTWebappConfig(storage, previousVersion)
}

async function migrateDefaultOpenaiApiConfig(
storage: Storage,
previousVersion: string
) {
const configKey = getProviderConfigKey(ProviderType.OpenaiChatApi)
const config = await storage.get<OpenAIProviderConfig>(configKey)
if (config.apiHost === undefined) {
await storage.set(configKey, { ...config, apiHost: defaultOpenaiAPIHost })
}
}

async function migrateDefaultChatGPTWebappConfig(
storage: Storage,
previousVersion: string
) {
const configKey = getProviderConfigKey(ProviderType.ChatGPTWebApp)
const config = await storage.get<ChatGPTWebAppProviderConfig>(configKey)
if (config.model === undefined) {
await storage.set(configKey, { ...config, model: defaultChatGPTWebModel })
}
}

export function getProviderConfigKey(provider: ProviderType) {
return `${providerTypeConfigKey}.${provider}`
}
Expand Down
9 changes: 9 additions & 0 deletions src/options/components/provider-select/chatgpt-webapp.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import {
ProviderType
} from "~config"
import { classNames } from "~utils"
import ModelSelect from "./model-select"

export default function ChatGPTWebAppProvider() {
const [config, setConfig] = useStorage<ChatGPTWebAppProviderConfig>(
Expand Down Expand Up @@ -45,6 +46,14 @@ export default function ChatGPTWebAppProvider() {
</span>
</Switch.Label>
</Switch.Group>
<div>
<label htmlFor="model" className="block text-lg font-medium">
Model
</label>
</div>
<div id="model">
<ModelSelect providerType={ProviderType.ChatGPTWebApp} />
</div>
</div>
) : null}
</>
Expand Down
6 changes: 3 additions & 3 deletions src/options/components/provider-select/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ import { RadioGroup } from "@headlessui/react"
import { useStorage } from "@plasmohq/storage/hook"
import {
ProviderType,
ProviderTypeName,
providerTypeConfigKey
providerTypeConfigKey,
ProviderTypeName
} from "~config/provider"
import { classNames } from "~utils"
import ChatGPTWebAppProvider from "./chatgpt-webapp"
Expand Down Expand Up @@ -37,7 +37,7 @@ export default function ProviderSelect() {
checked
? "border-0 bg-indigo-600 text-white hover:bg-indigo-700 dark:bg-indigo-500 dark:hover:bg-indigo-600"
: "border border-neutral-200 hover:bg-neutral-100 dark:border-neutral-500 dark:hover:bg-neutral-800",
"flex flex-1 cursor-pointer select-none items-center justify-center whitespace-nowrap rounded-md py-4 px-3 text-sm font-medium"
"flex flex-1 cursor-pointer select-none items-center justify-center whitespace-nowrap rounded-md px-3 py-4 text-sm font-medium"
)
}>
<RadioGroup.Label as="span">{option.name}</RadioGroup.Label>
Expand Down
64 changes: 32 additions & 32 deletions src/options/components/provider-select/model-select.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -2,51 +2,41 @@ import { Listbox, Transition } from "@headlessui/react"
import { useStorage } from "@plasmohq/storage/hook"
import { Fragment, useEffect, useState } from "react"
import {
ChatGPTWebModelNames,
getProviderConfigKey,
OpenAIProviderConfig,
ProviderType
} from "~config"
import { HiCheck, HiChevronUpDown } from "~icons"
import { Provider } from "~provider"
import { ChatGPTWebAppProvider } from "~provider/chatgpt-webapp"
import { OpenAIChatProvider } from "~provider/openai-chatapi"
import { classNames } from "~utils"

export default function ModelSelect() {
export default function ModelSelect(props: { providerType: ProviderType }) {
const { providerType } = props

const [models, setModels] = useState<string[]>([])
const [error, setError] = useState(false)
const [config, setConfig] = useStorage<OpenAIProviderConfig>(
getProviderConfigKey(ProviderType.OpenaiChatApi)
)
const [config, setConfig] = useStorage(getProviderConfigKey(providerType))

useEffect(() => {
if (!config) {
return
}

;(async () => {
let provider: Provider
if (providerType === ProviderType.ChatGPTWebApp) {
provider = new ChatGPTWebAppProvider(config)
} else if (providerType === ProviderType.OpenaiChatApi) {
provider = new OpenAIChatProvider(config)
} else {
throw new Error(`Unknown provider ${providerType}`)
}
try {
const resp = await fetch(`${config.apiHost}/v1/models`, {
method: "GET",
headers: {
Authorization: `Bearer ${config.apiKey}`
}
})
if (!resp.ok) {
setError(true)
}
const data = await resp.json().catch(() => {
setError(true)
})

// https://platform.openai.com/docs/models/model-endpoint-compatibility
const chatModels = []
for (const model of data.data) {
if (
model.id.startsWith("gpt-3.5-turbo") ||
model.id.startsWith("gpt-4")
) {
chatModels.push(model.id)
}
}
setModels(chatModels)
const models = await provider.fetchModels()
setModels(models)
setError(false)
} catch (err) {
setError(true)
Expand All @@ -58,6 +48,11 @@ export default function ModelSelect() {
return null
}

const getModelName = (model: string) => {
return providerType === ProviderType.ChatGPTWebApp
? ChatGPTWebModelNames[model] ?? model
: model
}
return (
<div className="flex flex-col gap-1">
<Listbox
Expand All @@ -69,7 +64,9 @@ export default function ModelSelect() {
<div className="block bg-white text-black dark:bg-neutral-900 dark:text-white">
<div className="relative w-[250]">
<Listbox.Button className="relative w-full cursor-default rounded-md py-1 pl-3 pr-10 text-left text-sm leading-6 shadow-sm ring-1 ring-inset ring-neutral-200 focus:outline-none focus:ring-2 focus:ring-indigo-500 dark:ring-neutral-500 dark:focus:ring-indigo-500">
<span className="ml-2 block truncate">{config.model}</span>
<span className="ml-2 block truncate">
{getModelName(config.model)}
</span>
<span className="pointer-events-none absolute inset-y-0 right-0 ml-3 flex items-center pr-2">
<HiChevronUpDown
className="h-5 w-5 text-neutral-200 dark:text-neutral-500"
Expand Down Expand Up @@ -103,7 +100,7 @@ export default function ModelSelect() {
selected ? "font-semibold" : "font-normal",
"ml-3 block truncate"
)}>
{model}
{getModelName(model)}
</span>
</div>

Expand All @@ -129,8 +126,11 @@ export default function ModelSelect() {
{error && (
<div className="text-xs text-red-600">
<p>
Failed to get model list, please check your network or API Host/API
Key setting.
Failed to get model list, please{" "}
{providerType === ProviderType.OpenaiChatApi
? "check your network or API Host/API Key setting"
: "login to chatgpt"}
.
</p>
<p>Will use model: {config.model}</p>
</div>
Expand Down
2 changes: 1 addition & 1 deletion src/options/components/provider-select/openai-api.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ export default function OpenAIAPIProvider() {
</label>
</div>
<div id="api-model">
<ModelSelect />
<ModelSelect providerType={ProviderType.OpenaiChatApi} />
</div>
</div>
) : null}
Expand Down
2 changes: 1 addition & 1 deletion src/provider/chatgpt-webapp/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ class ChatGPTWebAppClient {
return resp
}

async getModels(
async fetchModels(
token: string
): Promise<
{ slug: string; title: string; description: string; max_tokens: number }[]
Expand Down
64 changes: 35 additions & 29 deletions src/provider/chatgpt-webapp/index.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { random } from "lodash-es"
import { v4 as uuidv4 } from "uuid"
import { Prompt } from "~config"
import type { ChatGPTWebAppProviderConfig } from "~config/provider"
import { ChatGPTWebAppProviderConfig } from "~config/provider"
import {
ProviderBackendError,
ProviderError,
Expand All @@ -10,6 +10,24 @@ import { parseSSEResponse } from "~utils/sse"
import { Provider, type SummarizeParams } from ".."
import { chatGPTWebAppClient } from "./client"

function generateRandomHex(length: number) {
let result = ""
const characters = "0123456789abcdef"
for (let i = 0; i < length; i++) {
result += characters.charAt(Math.floor(Math.random() * characters.length))
}
return result
}

function generateArkoseToken() {
return `${generateRandomHex(
17
)}|r=ap-southeast-1|meta=3|meta_width=300|metabgclr=transparent|metaiconclr=%23555555|guitextcolor=%23000000|pk=35536E1E-65B4-4D96-9D97-6ADB7EFF8147|at=40|sup=1|rid=${random(
1,
99
)}|ag=101|cdn_url=https%3A%2F%2Ftcr9i.chat.openai.com%2Fcdn%2Ffc|lurl=https%3A%2F%2Faudio-ap-southeast-1.arkoselabs.com|surl=https%3A%2F%2Ftcr9i.chat.openai.com|smurl=https%3A%2F%2Ftcr9i.chat.openai.com%2Fcdn%2Ffc%2Fassets%2Fstyle-manager`
}

interface ConversationContext {
conversationId: string
lastMessageId: string
Expand All @@ -19,34 +37,12 @@ export class ChatGPTWebAppProvider extends Provider {
#accessToken?: string
#config: ChatGPTWebAppProviderConfig
#conversationContext?: ConversationContext
#cachedModelNames?: string[]

constructor(prompt: Prompt, config: ChatGPTWebAppProviderConfig) {
super(prompt)
constructor(config: ChatGPTWebAppProviderConfig) {
super()
this.#config = config
}

private async fetchModelNames(): Promise<string[]> {
if (this.#cachedModelNames) {
return this.#cachedModelNames
}
const resp = await chatGPTWebAppClient.getModels(this.#accessToken!)
this.#cachedModelNames = resp
.map((r) => r.slug)
.filter((slug) => !slug.includes("plugins"))
return this.#cachedModelNames
}

private async getModelName(): Promise<string> {
try {
const modelNames = await this.fetchModelNames()
return modelNames[0]
} catch (err) {
console.error(err)
return "text-davinci-002-render"
}
}

async doSummarize(text: string, params: SummarizeParams) {
const cleanup = this.#config.cleanup
? () => {
Expand All @@ -68,7 +64,6 @@ export class ChatGPTWebAppProvider extends Provider {
if (!this.#accessToken) {
this.#accessToken = await chatGPTWebAppClient.getAccessToken()
}
const modelName = await this.getModelName()

try {
let result = ""
Expand All @@ -93,7 +88,10 @@ export class ChatGPTWebAppProvider extends Provider {
}
}
],
model: modelName,
model: this.#config.model,
arkose_token: this.#config.model.startsWith("gpt-4")
? generateArkoseToken()
: undefined,
conversation_id:
this.#conversationContext?.conversationId || undefined,
parent_message_id:
Expand All @@ -116,7 +114,7 @@ export class ChatGPTWebAppProvider extends Provider {
cleanup()
return
}
let data
let data: any
try {
data = JSON.parse(message)
} catch (err) {
Expand Down Expand Up @@ -150,4 +148,12 @@ export class ChatGPTWebAppProvider extends Provider {
resetConversation() {
this.#conversationContext = undefined
}

async fetchModels() {
if (!this.#accessToken) {
this.#accessToken = await chatGPTWebAppClient.getAccessToken()
}
const resp = await chatGPTWebAppClient.fetchModels(this.#accessToken!)
return resp.map((r) => r.slug).filter((slug) => !slug.includes("plugins"))
}
}
Loading

0 comments on commit 69a887a

Please sign in to comment.