♻️ refactor: temperature range from 0 to 2 (lobehub#3355)
* feat: Update temperature settings for agent configuration

* fix

* All completed

* temperature / 2

* fix

* fix test

* add test

* ✅ test: fix test

* 🚨 ci: fix lint

---------

Co-authored-by: arvinxx <[email protected]>
sxjeru and arvinxx authored Sep 8, 2024
1 parent 988d744 commit 4a9114e
Showing 27 changed files with 290 additions and 27 deletions.
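In short: the default temperature moves from 0.6 to 1 and the UI slider now spans 0–2 (the OpenAI convention). Providers whose APIs accept a narrower 0–1 range halve the value before sending it, and unset or out-of-range values are sent as undefined so the provider default applies. A minimal TypeScript sketch of the recurring pattern — the helper name is illustrative, not code from this commit:

// Map a 0–2 UI temperature onto a provider that only accepts 0–1.
// An unset temperature stays undefined so the provider picks its own default.
const scaleTemperature = (uiTemperature?: number): number | undefined =>
  uiTemperature !== undefined ? uiTemperature / 2 : undefined;

scaleTemperature(1); // 0.5 — the new default lands mid-range
scaleTemperature(undefined); // undefined — provider default applies

Each provider below applies a variant of this rule; Groq and Perplexity, whose APIs already take values up to 2, pass the value through and only strip out-of-range inputs.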
2 changes: 1 addition & 1 deletion src/const/settings/agent.ts
@@ -26,7 +26,7 @@ export const DEFAULT_AGENT_CONFIG: LobeAgentConfig = {
params: {
frequency_penalty: 0,
presence_penalty: 0,
- temperature: 0.6,
+ temperature: 1,
top_p: 1,
},
plugins: [],
2 changes: 1 addition & 1 deletion src/database/client/schemas/session.ts
@@ -32,7 +32,7 @@ export const AgentSchema = z.object({
frequency_penalty: z.number().default(0).optional(),
max_tokens: z.number().optional(),
presence_penalty: z.number().default(0).optional(),
- temperature: z.number().default(0.6).optional(),
+ temperature: z.number().default(1).optional(),
top_p: z.number().default(1).optional(),
}),
plugins: z.array(z.string()).optional(),
2 changes: 1 addition & 1 deletion src/features/AgentSetting/AgentModal/index.tsx
@@ -36,7 +36,7 @@ const AgentModal = memo(() => {
tag: 'model',
},
{
- children: <SliderWithInput max={1} min={0} step={0.1} />,
+ children: <SliderWithInput max={2} min={0} step={0.1} />,
desc: t('settingModel.temperature.desc'),
label: t('settingModel.temperature.title'),
name: ['params', 'temperature'],
2 changes: 1 addition & 1 deletion src/features/ChatInput/ActionBar/Temperature.tsx
@@ -24,7 +24,7 @@ const Temperature = memo(() => {
content={
<SliderWithInput
controls={false}
- max={1}
+ max={2}
min={0}
onChange={(v) => {
updateAgentConfig({ params: { temperature: v } });
4 changes: 2 additions & 2 deletions src/libs/agent-runtime/anthropic/index.test.ts
@@ -155,7 +155,7 @@ describe('LobeAnthropicAI', () => {
messages: [{ content: 'Hello', role: 'user' }],
model: 'claude-3-haiku-20240307',
stream: true,
- temperature: 0.5,
+ temperature: 0.25,
top_p: 1,
},
{},
@@ -192,7 +192,7 @@
messages: [{ content: 'Hello', role: 'user' }],
model: 'claude-3-haiku-20240307',
stream: true,
- temperature: 0.5,
+ temperature: 0.25,
top_p: 1,
},
{},
5 changes: 4 additions & 1 deletion src/libs/agent-runtime/anthropic/index.ts
@@ -96,7 +96,10 @@ export class LobeAnthropicAI implements LobeRuntimeAI {
messages: buildAnthropicMessages(user_messages),
model,
system: system_message?.content as string,
- temperature,
+ temperature:
+   payload.temperature !== undefined
+     ? temperature / 2
+     : undefined,
tools: buildAnthropicTools(tools),
top_p,
} satisfies Anthropic.MessageCreateParams;
4 changes: 2 additions & 2 deletions src/libs/agent-runtime/bedrock/index.test.ts
@@ -192,7 +192,7 @@ describe('LobeBedrockAI', () => {
anthropic_version: 'bedrock-2023-05-31',
max_tokens: 2048,
messages: [{ content: 'Hello', role: 'user' }],
- temperature: 0.5,
+ temperature: 0.25,
top_p: 1,
}),
contentType: 'application/json',
@@ -230,7 +230,7 @@
anthropic_version: 'bedrock-2023-05-31',
max_tokens: 2048,
messages: [{ content: 'Hello', role: 'user' }],
- temperature: 0.5,
+ temperature: 0.25,
top_p: 1,
}),
contentType: 'application/json',
2 changes: 1 addition & 1 deletion src/libs/agent-runtime/bedrock/index.ts
@@ -64,7 +64,7 @@ export class LobeBedrockAI implements LobeRuntimeAI {
max_tokens: max_tokens || 4096,
messages: buildAnthropicMessages(user_messages),
system: system_message?.content as string,
- temperature: temperature,
+ temperature: temperature / 2,
tools: buildAnthropicTools(tools),
top_p: top_p,
}),
43 changes: 42 additions & 1 deletion src/libs/agent-runtime/groq/index.test.ts
@@ -2,7 +2,7 @@
import OpenAI from 'openai';
import { Mock, afterEach, beforeEach, describe, expect, it, vi } from 'vitest';

- import { ChatStreamCallbacks, LobeOpenAICompatibleRuntime } from '@/libs/agent-runtime';
+ import { LobeOpenAICompatibleRuntime } from '@/libs/agent-runtime';

import * as debugStreamModule from '../utils/debugStream';
import { LobeGroq } from './index';
@@ -318,3 +318,44 @@ describe('LobeGroqAI', () => {
});
});
});

describe('LobeGroqAI Temperature Tests', () => {
it('should set temperature to 0.7', async () => {
await instance.chat({
messages: [{ content: 'Hello', role: 'user' }],
model: 'mistralai/mistral-7b-instruct:free',
temperature: 0.7,
});

expect(instance['client'].chat.completions.create).toHaveBeenCalledWith(
expect.objectContaining({ temperature: 0.7 }),
expect.anything(),
);
});

it('should set temperature to 0', async () => {
await instance.chat({
messages: [{ content: 'Hello', role: 'user' }],
model: 'mistralai/mistral-7b-instruct:free',
temperature: 0,
});

expect(instance['client'].chat.completions.create).toHaveBeenCalledWith(
expect.objectContaining({ temperature: undefined }),
expect.anything(),
);
});

it('should set temperature to negative', async () => {
await instance.chat({
messages: [{ content: 'Hello', role: 'user' }],
model: 'mistralai/mistral-7b-instruct:free',
temperature: -1.0,
});

expect(instance['client'].chat.completions.create).toHaveBeenCalledWith(
expect.objectContaining({ temperature: undefined }),
expect.anything(),
);
});
});
5 changes: 4 additions & 1 deletion src/libs/agent-runtime/groq/index.ts
@@ -11,10 +11,13 @@ export const LobeGroq = LobeOpenAICompatibleFactory({
return { error, errorType: AgentRuntimeErrorType.LocationNotSupportError };
},
handlePayload: (payload) => {
+ const { temperature, ...restPayload } = payload;
return {
- ...payload,
+ ...restPayload,
// disable streaming when tools are used, since Groq does not support it
stream: !payload.tools,
+
+ temperature: temperature <= 0 ? undefined : temperature,
} as any;
},
},
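Note on the Groq rule above: the value is passed through unscaled — presumably because Groq's OpenAI-compatible API already accepts values up to 2 — and only non-positive temperatures are stripped so that Groq's own default applies, as the new tests pin down. A behavior sketch (the helper name is illustrative, not code from this commit):

// temperature 0.7 → sent as 0.7; 0 or below → omitted entirely
const groqTemperature = (t: number): number | undefined => (t <= 0 ? undefined : t);

groqTemperature(0.7); // 0.7
groqTemperature(0); // undefined
groqTemperature(-1); // undefined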
2 changes: 1 addition & 1 deletion src/libs/agent-runtime/minimax/index.test.ts
@@ -231,7 +231,7 @@ describe('LobeMinimaxAI', () => {
messages: [{ content: 'Hello', role: 'user' }],
model: 'text-davinci-003',
stream: true,
- temperature: 0.5,
+ temperature: 0.25,
top_p: 0.8,
});
});
5 changes: 4 additions & 1 deletion src/libs/agent-runtime/minimax/index.ts
@@ -141,7 +141,10 @@ export class LobeMinimaxAI implements LobeRuntimeAI {
...params,
max_tokens: this.getMaxTokens(payload.model),
stream: true,
- temperature: temperature === 0 ? undefined : temperature,
+ temperature:
+   temperature === undefined || temperature <= 0
+     ? undefined
+     : temperature / 2,
+
tools: params.tools?.map((tool) => ({
function: {
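Minimax combines both rules: unset or non-positive values are dropped, and anything else is halved into the 0–1 range — matching the updated test above, where 0.5 becomes 0.25. A sketch (helper name illustrative, not code from this commit):

const minimaxTemperature = (t?: number): number | undefined =>
  t === undefined || t <= 0 ? undefined : t / 2;

minimaxTemperature(0.5); // 0.25
minimaxTemperature(0); // undefined — Minimax applies its own default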
4 changes: 2 additions & 2 deletions src/libs/agent-runtime/mistral/index.test.ts
@@ -81,7 +81,7 @@ describe('LobeMistralAI', () => {
messages: [{ content: 'Hello', role: 'user' }],
model: 'open-mistral-7b',
stream: true,
- temperature: 0.7,
+ temperature: 0.35,
top_p: 1,
},
{ headers: { Accept: '*/*' } },
@@ -114,7 +114,7 @@
messages: [{ content: 'Hello', role: 'user' }],
model: 'open-mistral-7b',
stream: true,
- temperature: 0.7,
+ temperature: 0.35,
top_p: 1,
},
{ headers: { Accept: '*/*' } },
5 changes: 4 additions & 1 deletion src/libs/agent-runtime/mistral/index.ts
@@ -9,7 +9,10 @@ export const LobeMistralAI = LobeOpenAICompatibleFactory({
messages: payload.messages as any,
model: payload.model,
stream: true,
- temperature: payload.temperature,
+ temperature:
+   payload.temperature !== undefined
+     ? payload.temperature / 2
+     : undefined,
...payload.tools && { tools: payload.tools },
top_p: payload.top_p,
}),
17 changes: 16 additions & 1 deletion src/libs/agent-runtime/moonshot/index.ts
@@ -1,8 +1,23 @@
- import { ModelProvider } from '../types';
+ import OpenAI from 'openai';
+
+ import { ChatStreamPayload, ModelProvider } from '../types';
import { LobeOpenAICompatibleFactory } from '../utils/openaiCompatibleFactory';

export const LobeMoonshotAI = LobeOpenAICompatibleFactory({
baseURL: 'https://api.moonshot.cn/v1',
chatCompletion: {
handlePayload: (payload: ChatStreamPayload) => {
const { temperature, ...rest } = payload;

return {
...rest,
temperature:
temperature !== undefined
? temperature / 2
: undefined,
} as OpenAI.ChatCompletionCreateParamsStreaming;
},
},
debug: {
chatCompletion: () => process.env.DEBUG_MOONSHOT_CHAT_COMPLETION === '1',
},
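Moonshot's new handlePayload halves any explicit temperature, presumably because Moonshot's API, like Anthropic's, tops out at 1 rather than 2 — an inference; the commit does not state the rationale. Behavior sketch with an illustrative helper:

const moonshotTemperature = (t?: number): number | undefined =>
  t !== undefined ? t / 2 : undefined;

moonshotTemperature(1); // 0.5 — the new UI default lands mid-range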
27 changes: 27 additions & 0 deletions src/libs/agent-runtime/ollama/index.test.ts
@@ -103,6 +103,33 @@ describe('LobeOllamaAI', () => {

expect(abortMock).toHaveBeenCalled();
});

it('temperature should be divided by two', async () => {
const chatMock = vi.fn().mockResolvedValue({});
vi.mocked(Ollama.prototype.chat).mockImplementation(chatMock);

const payload = {
messages: [{ content: 'Hello', role: 'user' }],
model: 'model-id',
temperature: 0.7,
};
const options = { signal: new AbortController().signal };

const response = await ollamaAI.chat(payload as any, options);

expect(chatMock).toHaveBeenCalledWith({
messages: [{ content: 'Hello', role: 'user' }],
model: 'model-id',
options: {
frequency_penalty: undefined,
presence_penalty: undefined,
temperature: 0.35,
top_p: undefined,
},
stream: true,
});
expect(response).toBeInstanceOf(Response);
});
});

describe('models', () => {
5 changes: 4 additions & 1 deletion src/libs/agent-runtime/ollama/index.ts
@@ -45,7 +45,10 @@ export class LobeOllamaAI implements LobeRuntimeAI {
options: {
frequency_penalty: payload.frequency_penalty,
presence_penalty: payload.presence_penalty,
- temperature: payload.temperature,
+ temperature:
+   payload.temperature !== undefined
+     ? payload.temperature / 2
+     : undefined,
top_p: payload.top_p,
},
stream: true,
42 changes: 42 additions & 0 deletions src/libs/agent-runtime/perplexity/index.test.ts
@@ -245,5 +245,47 @@ describe('LobePerplexityAI', () => {
process.env.DEBUG_PERPLEXITY_CHAT_COMPLETION = originalDebugValue;
});
});

it('should call chat method with temperature', async () => {
vi.spyOn(instance['client'].chat.completions, 'create').mockResolvedValue(
new ReadableStream() as any,
);

await instance.chat({
messages: [{ content: 'Hello', role: 'user' }],
model: 'text-davinci-003',
temperature: 1.5,
});

expect(instance['client'].chat.completions.create).toHaveBeenCalledWith(
expect.objectContaining({
messages: expect.any(Array),
model: 'text-davinci-003',
temperature: 1.5,
}),
expect.any(Object),
);
});

it('should be undefined when temperature >= 2', async () => {
vi.spyOn(instance['client'].chat.completions, 'create').mockResolvedValue(
new ReadableStream() as any,
);

await instance.chat({
messages: [{ content: 'Hello', role: 'user' }],
model: 'text-davinci-003',
temperature: 2,
});

expect(instance['client'].chat.completions.create).toHaveBeenCalledWith(
expect.objectContaining({
messages: expect.any(Array),
model: 'text-davinci-003',
temperature: undefined,
}),
expect.any(Object),
);
});
});
});
9 changes: 7 additions & 2 deletions src/libs/agent-runtime/perplexity/index.ts
@@ -8,7 +8,7 @@ export const LobePerplexityAI = LobeOpenAICompatibleFactory({
chatCompletion: {
handlePayload: (payload: ChatStreamPayload) => {
// Set a default frequency penalty value greater than 0
- const { presence_penalty, frequency_penalty, stream = true, ...res } = payload;
+ const { presence_penalty, frequency_penalty, stream = true, temperature, ...res } = payload;

let param;

@@ -21,7 +21,7 @@ export const LobePerplexityAI = LobeOpenAICompatibleFactory({
param = { frequency_penalty: frequency_penalty || defaultFrequencyPenalty };
}

- return { ...res, ...param, stream } as OpenAI.ChatCompletionCreateParamsStreaming;
+ return {
+   ...res,
+   ...param,
+   stream,
+   temperature: temperature >= 2 ? undefined : temperature,
+ } as OpenAI.ChatCompletionCreateParamsStreaming;
},
},
debug: {
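Perplexity keeps the raw value and only drops anything at 2 or above, as the new tests verify (1.5 passes through; 2 becomes undefined) — plausibly because Perplexity documents its temperature as strictly less than 2, though the commit does not say so. Sketch (helper name illustrative):

const perplexityTemperature = (t?: number): number | undefined =>
  t !== undefined && t >= 2 ? undefined : t;

perplexityTemperature(1.5); // 1.5
perplexityTemperature(2); // undefined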