
Commit

πŸ› fix: fix azure openai env and support enhanced custom models env (lobehub#2001)

* πŸ› fix: fix not enabled azure openai

* πŸ› fix: fix user define model meta not work

* πŸ› fix: use default server enabledModels

* ⚑️ perf: support more powerful env

* βœ… test: improve test

* ⚑️ perf: support Azure Model List

* βœ… test: fix test
arvinxx authored Apr 12, 2024
1 parent 4743075 commit 899b784
Showing 41 changed files with 594 additions and 388 deletions.
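
For context on the "support more powerful env" and "support Azure Model List" items above: this commit has each provider's `*_MODEL_LIST` env parsed by two helpers, `extractEnabledModels` and `transformToChatModelCards` (both visible in the `src/app/api/config/route.ts` diff below). A minimal sketch of how they consume a model string, using the same ids as the updated snapshots; anything beyond the `-all` / `+add` / `-remove` / `id=displayName` grammar is an assumption:

```ts
import { OpenAIProviderCard } from '@/config/modelProviders';
import { extractEnabledModels, transformToChatModelCards } from '@/utils/parseModels';

// "-all" drops every built-in model, "+id" adds a model, "id=displayName" renames one.
const modelString = '-all,+llama,+claude-2,gpt-4-0125-preview=gpt-4-32k';

// Ids the server reports as enabled, e.g. ['llama', 'claude-2', 'gpt-4-0125-preview'].
const enabledModels = extractEnabledModels(modelString);

// Full card objects, merged against the provider's built-in chat models.
const serverModelCards = transformToChatModelCards({
  defaultChatModels: OpenAIProviderCard.chatModels,
  modelString,
});
```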
88 changes: 45 additions & 43 deletions src/app/api/config/__snapshots__/route.test.ts.snap
@@ -5,16 +5,12 @@ exports[`GET /api/config > Model Provider env > CUSTOM_MODELS > custom deletion,
{
"displayName": "llama",
"enabled": true,
"functionCall": true,
"id": "llama",
"vision": true,
},
{
"displayName": "claude-2",
"enabled": true,
"functionCall": true,
"id": "claude-2",
"vision": true,
},
{
"displayName": "gpt-4-32k",
@@ -27,29 +23,32 @@
`;

exports[`GET /api/config > Model Provider env > OPENAI_MODEL_LIST > custom deletion, addition, and renaming of models 1`] = `
[
{
"displayName": "llama",
"enabled": true,
"functionCall": true,
"id": "llama",
"vision": true,
},
{
"displayName": "claude-2",
"enabled": true,
"functionCall": true,
"id": "claude-2",
"vision": true,
},
{
"displayName": "gpt-4-32k",
"enabled": true,
"functionCall": true,
"id": "gpt-4-0125-preview",
"tokens": 128000,
},
]
{
"enabledModels": [
"llama",
"claude-2",
"gpt-4-0125-preview",
],
"serverModelCards": [
{
"displayName": "llama",
"enabled": true,
"id": "llama",
},
{
"displayName": "claude-2",
"enabled": true,
"id": "claude-2",
},
{
"displayName": "gpt-4-32k",
"enabled": true,
"functionCall": true,
"id": "gpt-4-0125-preview",
"tokens": 128000,
},
],
}
`;

exports[`GET /api/config > Model Provider env > OPENAI_MODEL_LIST > should work correct with gpt-4 1`] = `
@@ -108,20 +107,23 @@ exports[`GET /api/config > Model Provider env > OPENAI_MODEL_LIST > should work
`;

exports[`GET /api/config > Model Provider env > OPENROUTER_MODEL_LIST > custom deletion, addition, and renaming of models 1`] = `
[
{
"displayName": "google/gemma-7b-it",
"enabled": true,
"functionCall": true,
"id": "google/gemma-7b-it",
"vision": true,
},
{
"displayName": "Mistral-7B-Instruct",
"enabled": true,
"functionCall": true,
"id": "mistralai/mistral-7b-instruct",
"vision": true,
},
]
{
"enabled": false,
"enabledModels": [
"google/gemma-7b-it",
"mistralai/mistral-7b-instruct",
],
"serverModelCards": [
{
"displayName": "google/gemma-7b-it",
"enabled": true,
"id": "google/gemma-7b-it",
},
{
"displayName": "Mistral-7B-Instruct",
"enabled": true,
"id": "mistralai/mistral-7b-instruct",
},
],
}
`;
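
Taken together, the updated snapshots show each provider's server config moving from a bare model-card array to an object carrying the enabled flag, the enabled model ids, and the cards. A sketch of the implied shape (field optionality and the `ChatModelCard` import path are assumptions):

```ts
import { ChatModelCard } from '@/types/llm'; // import path assumed

interface ProviderServerConfig {
  enabled?: boolean;                  // provider switched on or off server-side
  enabledModels?: string[];           // ids extracted from the provider's *_MODEL_LIST env
  serverModelCards?: ChatModelCard[]; // card metadata sent to the client
}
```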
2 changes: 1 addition & 1 deletion src/app/api/config/parseDefaultAgent.test.ts
@@ -87,7 +87,7 @@ describe('parseAgentConfig', () => {
});

describe('complex environment', () => {
it.skip('parses a complete environment variable string correctly', () => {
it('parses environment variable string correctly', () => {
const envStr =
'model=gpt-4-1106-preview;params.max_tokens=300;plugins=search-engine,lobe-image-designer';
const expected = {
12 changes: 2 additions & 10 deletions src/app/api/config/route.test.ts
@@ -23,7 +23,7 @@ describe('GET /api/config', () => {

const jsonResponse: GlobalServerConfig = await response.json();

const result = jsonResponse.languageModel?.openai?.serverModelCards;
const result = jsonResponse.languageModel?.openai;

expect(result).toMatchSnapshot();
process.env.OPENAI_MODEL_LIST = '';
@@ -101,31 +101,23 @@ describe('GET /api/config', () => {

expect(result).toContainEqual({
displayName: 'model1',
functionCall: true,
id: 'model1',
enabled: true,
vision: true,
});
expect(result).toContainEqual({
displayName: 'model2',
functionCall: true,
enabled: true,
id: 'model2',
vision: true,
});
expect(result).toContainEqual({
displayName: 'model3',
enabled: true,
functionCall: true,
id: 'model3',
vision: true,
});
expect(result).toContainEqual({
displayName: 'model4',
functionCall: true,
enabled: true,
id: 'model4',
vision: true,
});

process.env.OPENAI_MODEL_LIST = '';
@@ -159,7 +151,7 @@ describe('GET /api/config', () => {
const res = await GET();
const data: GlobalServerConfig = await res.json();

const result = data.languageModel?.openrouter?.serverModelCards;
const result = data.languageModel?.openrouter;

expect(result).toMatchSnapshot();

53 changes: 37 additions & 16 deletions src/app/api/config/route.ts
@@ -1,11 +1,12 @@
import {
OllamaProviderCard,
OpenAIProviderCard,
OpenRouterProviderCard,
TogetherAIProviderCard,
} from '@/config/modelProviders';
import { getServerConfig } from '@/config/server';
import { GlobalServerConfig } from '@/types/serverConfig';
import { transformToChatModelCards } from '@/utils/parseModels';
import { extractEnabledModels, transformToChatModelCards } from '@/utils/parseModels';

import { parseAgentConfig } from './parseDefaultAgent';

@@ -31,6 +32,9 @@ export const GET = async () => {
ENABLED_ANTHROPIC,
ENABLED_MISTRAL,

ENABLED_AZURE_OPENAI,
AZURE_MODEL_LIST,

ENABLE_OLLAMA,
OLLAMA_MODEL_LIST,

@@ -49,38 +53,55 @@

enabledOAuthSSO: ENABLE_OAUTH_SSO,
languageModel: {
anthropic: { enabled: ENABLED_ANTHROPIC },
anthropic: {
enabled: ENABLED_ANTHROPIC,
},
azure: {
enabled: ENABLED_AZURE_OPENAI,
enabledModels: extractEnabledModels(AZURE_MODEL_LIST, true),
serverModelCards: transformToChatModelCards({
defaultChatModels: [],
modelString: AZURE_MODEL_LIST,
withDeploymentName: true,
}),
},
bedrock: { enabled: ENABLED_AWS_BEDROCK },
google: { enabled: ENABLED_GOOGLE },
groq: { enabled: ENABLED_GROQ },
mistral: { enabled: ENABLED_MISTRAL },
moonshot: { enabled: ENABLED_MOONSHOT },

ollama: {
enabled: ENABLE_OLLAMA,
serverModelCards: transformToChatModelCards(
OLLAMA_MODEL_LIST,
OllamaProviderCard.chatModels,
),
serverModelCards: transformToChatModelCards({
defaultChatModels: OllamaProviderCard.chatModels,
modelString: OLLAMA_MODEL_LIST,
}),
},
openai: {
serverModelCards: transformToChatModelCards(OPENAI_MODEL_LIST),
enabledModels: extractEnabledModels(OPENAI_MODEL_LIST),
serverModelCards: transformToChatModelCards({
defaultChatModels: OpenAIProviderCard.chatModels,
modelString: OPENAI_MODEL_LIST,
}),
},

openrouter: {
enabled: ENABLED_OPENROUTER,
serverModelCards: transformToChatModelCards(
OPENROUTER_MODEL_LIST,
OpenRouterProviderCard.chatModels,
),
enabledModels: extractEnabledModels(OPENROUTER_MODEL_LIST),
serverModelCards: transformToChatModelCards({
defaultChatModels: OpenRouterProviderCard.chatModels,
modelString: OPENROUTER_MODEL_LIST,
}),
},
perplexity: { enabled: ENABLED_PERPLEXITY },

togetherai: {
enabled: ENABLED_TOGETHERAI,
serverModelCards: transformToChatModelCards(
TOGETHERAI_MODEL_LIST,
TogetherAIProviderCard.chatModels,
),
enabledModels: extractEnabledModels(TOGETHERAI_MODEL_LIST),
serverModelCards: transformToChatModelCards({
defaultChatModels: TogetherAIProviderCard.chatModels,
modelString: TOGETHERAI_MODEL_LIST,
}),
},

zeroone: { enabled: ENABLED_ZEROONE },
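
The `azure` branch above is the new part: `ENABLED_AZURE_OPENAI` replaces the request-time `USE_AZURE_OPENAI` switch (removed from `createBizOpenAI` below), and `AZURE_MODEL_LIST` is parsed with `withDeploymentName: true`. A sketch of what that flag implies, given that the Azure settings screen reads `deploymentName` off the first card; the env syntax for declaring deployment names is not shown in this diff, so it is left out here:

```ts
// Each parsed Azure card carries a deploymentName alongside its model id.
const azureCards = transformToChatModelCards({
  defaultChatModels: [],
  modelString: process.env.AZURE_MODEL_LIST || '',
  withDeploymentName: true,
});

// Used as the connectivity-check model in src/app/settings/llm/Azure/index.tsx.
const checkModel = azureCards.length > 0 ? azureCards[0].deploymentName : undefined;
```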
16 changes: 3 additions & 13 deletions src/app/api/openai/createBizOpenAI/index.ts
@@ -1,21 +1,18 @@
import OpenAI from 'openai';

import { checkAuth } from '@/app/api/auth';
import { getServerConfig } from '@/config/server';
import { getOpenAIAuthFromRequest } from '@/const/fetch';
import { ChatErrorType, ErrorType } from '@/types/fetch';

import { createErrorResponse } from '../../errorResponse';
import { createAzureOpenai } from './createAzureOpenai';
import { createOpenai } from './createOpenai';

/**
* createOpenAI Instance with Auth and azure openai support
* if auth not pass ,just return error response
*/
export const createBizOpenAI = (req: Request, model: string): Response | OpenAI => {
const { apiKey, accessCode, endpoint, useAzure, apiVersion, oauthAuthorized } =
getOpenAIAuthFromRequest(req);
export const createBizOpenAI = (req: Request): Response | OpenAI => {
const { apiKey, accessCode, endpoint, oauthAuthorized } = getOpenAIAuthFromRequest(req);

const result = checkAuth({ accessCode, apiKey, oauthAuthorized });

@@ -25,15 +22,8 @@ export const createBizOpenAI = (req: Request, model: string): Response | OpenAI

let openai: OpenAI;

const { USE_AZURE_OPENAI } = getServerConfig();
const useAzureOpenAI = useAzure || USE_AZURE_OPENAI;

try {
if (useAzureOpenAI) {
openai = createAzureOpenai({ apiVersion, endpoint, model, userApiKey: apiKey });
} else {
openai = createOpenai(apiKey, endpoint);
}
openai = createOpenai(apiKey, endpoint);
} catch (error) {
if ((error as Error).cause === ChatErrorType.NoOpenAIAPIKey) {
return createErrorResponse(ChatErrorType.NoOpenAIAPIKey);
2 changes: 1 addition & 1 deletion src/app/api/openai/images/route.ts
@@ -8,7 +8,7 @@ export const runtime = 'edge';
export const POST = async (req: Request) => {
const payload = (await req.json()) as OpenAIImagePayload;

const openaiOrErrResponse = createBizOpenAI(req, payload.model);
const openaiOrErrResponse = createBizOpenAI(req);
// if resOrOpenAI is a Response, it means there is an error,just return it
if (openaiOrErrResponse instanceof Response) return openaiOrErrResponse;

2 changes: 1 addition & 1 deletion src/app/api/openai/stt/route.ts
@@ -16,7 +16,7 @@ export const POST = async (req: Request) => {
speech: speechBlob,
} as OpenAISTTPayload;

const openaiOrErrResponse = createBizOpenAI(req, payload.options.model);
const openaiOrErrResponse = createBizOpenAI(req);

// if resOrOpenAI is a Response, it means there is an error,just return it
if (openaiOrErrResponse instanceof Response) return openaiOrErrResponse;
2 changes: 1 addition & 1 deletion src/app/api/openai/tts/route.ts
@@ -10,7 +10,7 @@ export const preferredRegion = getPreferredRegion();
export const POST = async (req: Request) => {
const payload = (await req.json()) as OpenAITTSPayload;

const openaiOrErrResponse = createBizOpenAI(req, payload.options.model);
const openaiOrErrResponse = createBizOpenAI(req);

// if resOrOpenAI is a Response, it means there is an error,just return it
if (openaiOrErrResponse instanceof Response) return openaiOrErrResponse;
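
With Azure handling gone from `createBizOpenAI`, all three OpenAI routes (`images`, `stt`, `tts`) drop the model argument and share the same pattern, quoted here from the diffs above:

```ts
const openaiOrErrResponse = createBizOpenAI(req);

// A Response here means auth failed or no API key was resolved; return it as-is.
if (openaiOrErrResponse instanceof Response) return openaiOrErrResponse;

// Otherwise it is a ready-to-use OpenAI client.
```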
2 changes: 1 addition & 1 deletion src/app/chat/(desktop)/features/ChatHeader/Tags.tsx
@@ -16,7 +16,7 @@ const TitleTags = memo(() => {
agentSelectors.currentAgentPlugins(s),
]);

const showPlugin = useGlobalStore(modelProviderSelectors.modelEnabledFunctionCall(model));
const showPlugin = useGlobalStore(modelProviderSelectors.isModelEnabledFunctionCall(model));

return (
<Flexbox gap={8} horizontal>
2 changes: 1 addition & 1 deletion src/app/chat/(desktop)/features/ChatInput/Footer/index.tsx
@@ -71,7 +71,7 @@ const Footer = memo<{ setExpand?: (expand: boolean) => void }>(({ setExpand }) =
]);

const model = useSessionStore(agentSelectors.currentAgentModel);
const canUpload = useGlobalStore(modelProviderSelectors.modelEnabledUpload(model));
const canUpload = useGlobalStore(modelProviderSelectors.isModelEnabledUpload(model));

const sendMessage = useSendMessage();

4 changes: 2 additions & 2 deletions src/app/settings/llm/Azure/index.tsx
@@ -8,7 +8,7 @@ import { Flexbox } from 'react-layout-kit';

import { ModelProvider } from '@/libs/agent-runtime';
import { useGlobalStore } from '@/store/global';
import { modelConfigSelectors } from '@/store/global/selectors';
import { modelProviderSelectors } from '@/store/global/selectors';

import ProviderConfig from '../components/ProviderConfig';
import { LLMProviderApiTokenKey, LLMProviderBaseUrlKey, LLMProviderConfigKey } from '../const';
@@ -34,7 +34,7 @@ const AzureOpenAIProvider = memo(() => {

// Get the first model card's deployment name as the check model
const checkModel = useGlobalStore((s) => {
const chatModelCards = modelConfigSelectors.getModelCardsByProviderId(providerKey)(s);
const chatModelCards = modelProviderSelectors.getModelCardsById(providerKey)(s);

if (chatModelCards.length > 0) {
return chatModelCards[0].deploymentName;
@@ -8,7 +8,11 @@ import { useTranslation } from 'react-i18next';
import { Flexbox } from 'react-layout-kit';

import { useGlobalStore } from '@/store/global';
import { modelConfigSelectors } from '@/store/global/selectors';
import {
modelConfigSelectors,
modelProviderSelectors,
settingsSelectors,
} from '@/store/global/selectors';
import { GlobalLLMProviderKey } from '@/types/settings';

const useStyles = createStyles(({ css, token }) => ({
@@ -38,10 +42,10 @@ const ModelFetcher = memo<ModelFetcherProps>(({ provider }) => {
]);
const enabledAutoFetch = useGlobalStore(modelConfigSelectors.isAutoFetchModelsEnabled(provider));
const latestFetchTime = useGlobalStore(
(s) => modelConfigSelectors.getConfigByProviderId(provider)(s)?.latestFetchTime,
(s) => settingsSelectors.providerConfig(provider)(s)?.latestFetchTime,
);
const totalModels = useGlobalStore(
(s) => modelConfigSelectors.getModelCardsByProviderId(provider)(s).length,
(s) => modelProviderSelectors.getModelCardsById(provider)(s).length,
);

const { mutate, isValidating } = useFetchProviderModelList(provider, enabledAutoFetch);
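
For reference, the selector renames this commit applies across the UI layer, collected from the diffs above (one example call site shown; shapes as they appear in the changed files):

```ts
// modelProviderSelectors.modelEnabledFunctionCall → modelProviderSelectors.isModelEnabledFunctionCall
// modelProviderSelectors.modelEnabledUpload       → modelProviderSelectors.isModelEnabledUpload
// modelConfigSelectors.getModelCardsByProviderId  → modelProviderSelectors.getModelCardsById
// modelConfigSelectors.getConfigByProviderId(p)(s)?.latestFetchTime
//   → settingsSelectors.providerConfig(p)(s)?.latestFetchTime

const canUpload = useGlobalStore(modelProviderSelectors.isModelEnabledUpload(model));
```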
