Skip to content

Commit

Permalink
feat (provider/openai-compatible): Add 'apiKey' option for concise di…
Browse files Browse the repository at this point in the history
…rect use. (vercel#4293)
  • Loading branch information
shaper authored Jan 7, 2025
1 parent 3ac8cce commit 92ac806
Show file tree
Hide file tree
Showing 6 changed files with 232 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,7 @@ import { createOpenAICompatible } from '@ai-sdk/openai-compatible';

const perplexity = createOpenAICompatible({
name: 'perplexity',
headers: {
Authorization: `Bearer ${process.env.PERPLEXITY_API_KEY}`,
},
apiKey: process.env.PERPLEXITY_API_KEY,
baseURL: 'https://api.perplexity.ai/',
});
```
Expand All @@ -60,9 +58,7 @@ import { generateText } from 'ai';

const perplexity = createOpenAICompatible({
name: 'perplexity',
headers: {
Authorization: `Bearer ${process.env.PERPLEXITY_API_KEY}`,
},
apiKey: process.env.PERPLEXITY_API_KEY,
baseURL: 'https://api.perplexity.ai/',
});

Expand Down
12 changes: 3 additions & 9 deletions content/providers/02-openai-compatible-providers/40-baseten.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,7 @@ const basetenExtraPayload = {

const baseten = createOpenAICompatible({
name: 'baseten',
headers: {
Authorization: `Bearer ${process.env.BASETEN_API_KEY}`,
},
apiKey: process.env.BASETEN_API_KEY,
baseURL: 'https://bridge.baseten.co/v1/direct',
fetch: async (url, request) => {
const bodyWithBasetenPayload = JSON.stringify({
Expand Down Expand Up @@ -87,9 +85,7 @@ const basetenExtraPayload = {

const baseten = createOpenAICompatible({
name: 'baseten',
headers: {
Authorization: `Bearer ${process.env.BASETEN_API_KEY}`,
},
apiKey: process.env.BASETEN_API_KEY,
baseURL: 'https://bridge.baseten.co/v1/direct',
fetch: async (url, request) => {
const bodyWithBasetenPayload = JSON.stringify({
Expand Down Expand Up @@ -125,9 +121,7 @@ const basetenExtraPayload = {

const baseten = createOpenAICompatible({
name: 'baseten',
headers: {
Authorization: `Bearer ${process.env.BASETEN_API_KEY}`,
},
apiKey: process.env.BASETEN_API_KEY,
baseURL: 'https://bridge.baseten.co/v1/direct',
fetch: async (url, request) => {
const bodyWithBasetenPayload = JSON.stringify({
Expand Down
12 changes: 3 additions & 9 deletions content/providers/02-openai-compatible-providers/index.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,7 @@ import { createOpenAICompatible } from '@ai-sdk/openai-compatible';

const provider = createOpenAICompatible({
name: 'provider-name',
headers: {
Authorization: `Bearer ${process.env.PROVIDER_API_KEY}`,
},
apiKey: process.env.PROVIDER_API_KEY,
baseURL: 'https://api.provider.com/v1',
});
```
Expand All @@ -68,9 +66,7 @@ import { generateText } from 'ai'

const provider = createOpenAICompatible({
name: 'provider-name',
headers: {
Authorization: `Bearer ${process.env.PROVIDER_API_KEY}`,
},
apiKey: process.env.PROVIDER_API_KEY,
baseURL: 'https://api.provider.com/v1'
})

Expand Down Expand Up @@ -107,9 +103,7 @@ const model = createOpenAICompatible<
ExampleEmbeddingModelIds
>({
name: 'example',
headers: {
Authorization: `Bearer ${process.env.MY_API_KEY}`,
},
apiKey: process.env.PROVIDER_API_KEY,
baseURL: 'https://api.example.com/v1',
});

Expand Down
22 changes: 19 additions & 3 deletions packages/openai-compatible/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,24 @@ import { createOpenAICompatible } from '@ai-sdk/openai-compatible';
import { createOpenAICompatible } from '@ai-sdk/openai-compatible';
import { generateText } from 'ai';

const { text } = await generateText({
model: createOpenAICompatible({
baseURL: 'https://api.example.com/v1',
name: 'example',
apiKey: process.env.MY_API_KEY,
}).chatModel('meta-llama/Llama-3-70b-chat-hf'),
prompt: 'Write a vegetarian lasagna recipe for 4 people.',
});
```

### Customizing headers

You can further customize headers if desired. For example, here is an alternate implementation to pass along api key authentication:

```ts
import { createOpenAICompatible } from '@ai-sdk/openai-compatible';
import { generateText } from 'ai';

const { text } = await generateText({
model: createOpenAICompatible({
baseURL: 'https://api.example.com/v1',
Expand Down Expand Up @@ -66,9 +84,7 @@ const model = createOpenAICompatible<
>({
baseURL: 'https://api.example.com/v1',
name: 'example',
headers: {
Authorization: `Bearer ${process.env.MY_API_KEY}`,
},
apiKey: process.env.MY_API_KEY,
});

// Subsequent calls to e.g. `model.chatModel` will auto-complete the model id
Expand Down
181 changes: 181 additions & 0 deletions packages/openai-compatible/src/openai-compatible-provider.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
import { describe, it, expect, vi, beforeEach } from 'vitest';
import { createOpenAICompatible } from './openai-compatible-provider';
import { OpenAICompatibleChatLanguageModel } from './openai-compatible-chat-language-model';
import { OpenAICompatibleCompletionLanguageModel } from './openai-compatible-completion-language-model';
import { OpenAICompatibleEmbeddingModel } from './openai-compatible-embedding-model';
import { OpenAICompatibleChatSettings } from './openai-compatible-chat-settings';

const OpenAICompatibleChatLanguageModelMock = vi.mocked(
OpenAICompatibleChatLanguageModel,
);
const OpenAICompatibleCompletionLanguageModelMock = vi.mocked(
OpenAICompatibleCompletionLanguageModel,
);
const OpenAICompatibleEmbeddingModelMock = vi.mocked(
OpenAICompatibleEmbeddingModel,
);

vi.mock('./openai-compatible-chat-language-model', () => ({
OpenAICompatibleChatLanguageModel: vi.fn(),
}));

vi.mock('./openai-compatible-completion-language-model', () => ({
OpenAICompatibleCompletionLanguageModel: vi.fn(),
}));

vi.mock('./openai-compatible-embedding-model', () => ({
OpenAICompatibleEmbeddingModel: vi.fn(),
}));

describe('OpenAICompatibleProvider', () => {
beforeEach(() => {
vi.clearAllMocks();
});

describe('createOpenAICompatible', () => {
it('should throw error if baseURL is not provided', () => {
expect(() => createOpenAICompatible({ name: 'test-provider' })).toThrow(
'Base URL is required',
);
});

it('should throw error if name is not provided', () => {
expect(() =>
createOpenAICompatible({ baseURL: 'https://api.example.com' }),
).toThrow('Provider name is required');
});

it('should create provider with correct configuration', () => {
const options = {
baseURL: 'https://api.example.com',
name: 'test-provider',
apiKey: 'test-api-key',
headers: { 'Custom-Header': 'value' },
};

const provider = createOpenAICompatible(options);
const model = provider('model-id');

const constructorCall =
OpenAICompatibleChatLanguageModelMock.mock.calls[0];
const config = constructorCall[2];
const headers = config.headers();

expect(headers).toEqual({
Authorization: 'Bearer test-api-key',
'Custom-Header': 'value',
});
expect(config.provider).toBe('test-provider.chat');
expect(config.url({ modelId: 'model-id', path: '/v1/chat' })).toBe(
'https://api.example.com/v1/chat',
);
});

it('should create headers without Authorization when no apiKey provided', () => {
const options = {
baseURL: 'https://api.example.com',
name: 'test-provider',
headers: { 'Custom-Header': 'value' },
};

const provider = createOpenAICompatible(options);
provider('model-id');

const constructorCall =
OpenAICompatibleChatLanguageModelMock.mock.calls[0];
const config = constructorCall[2];
const headers = config.headers();

expect(headers).toEqual({
'Custom-Header': 'value',
});
});
});

describe('model creation methods', () => {
const defaultOptions = {
baseURL: 'https://api.example.com',
name: 'test-provider',
apiKey: 'test-api-key',
headers: { 'Custom-Header': 'value' },
};

it('should create chat model with correct configuration', () => {
const provider = createOpenAICompatible(defaultOptions);
const settings: OpenAICompatibleChatSettings = {};

provider.chatModel('chat-model', settings);

const constructorCall =
OpenAICompatibleChatLanguageModelMock.mock.calls[0];
const config = constructorCall[2];
const headers = config.headers();

expect(headers).toEqual({
Authorization: 'Bearer test-api-key',
'Custom-Header': 'value',
});
expect(config.provider).toBe('test-provider.chat');
expect(config.url({ modelId: 'model-id', path: '/v1/chat' })).toBe(
'https://api.example.com/v1/chat',
);
});

it('should create completion model with correct configuration', () => {
const provider = createOpenAICompatible(defaultOptions);
const settings: OpenAICompatibleChatSettings = {};

provider.completionModel('completion-model', settings);

const constructorCall =
OpenAICompatibleCompletionLanguageModelMock.mock.calls[0];
const config = constructorCall[2];
const headers = config.headers();

expect(headers).toEqual({
Authorization: 'Bearer test-api-key',
'Custom-Header': 'value',
});
expect(config.provider).toBe('test-provider.completion');
expect(
config.url({ modelId: 'completion-model', path: '/v1/completions' }),
).toBe('https://api.example.com/v1/completions');
});

it('should create embedding model with correct configuration', () => {
const provider = createOpenAICompatible(defaultOptions);
const settings: OpenAICompatibleChatSettings = {};

provider.textEmbeddingModel('embedding-model', settings);

const constructorCall = OpenAICompatibleEmbeddingModelMock.mock.calls[0];
const config = constructorCall[2];
const headers = config.headers();

expect(headers).toEqual({
Authorization: 'Bearer test-api-key',
'Custom-Header': 'value',
});
expect(config.provider).toBe('test-provider.embedding');
expect(
config.url({ modelId: 'embedding-model', path: '/v1/embeddings' }),
).toBe('https://api.example.com/v1/embeddings');
});

it('should use languageModel as default when called as function', () => {
const provider = createOpenAICompatible(defaultOptions);
const settings: OpenAICompatibleChatSettings = {};

provider('model-id', settings);

expect(OpenAICompatibleChatLanguageModel).toHaveBeenCalledWith(
'model-id',
settings,
expect.objectContaining({
provider: 'test-provider.chat',
defaultObjectGenerationMode: 'tool',
}),
);
});
});
});
31 changes: 24 additions & 7 deletions packages/openai-compatible/src/openai-compatible-provider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,11 @@ import {
LanguageModelV1,
ProviderV1,
} from '@ai-sdk/provider';
import { FetchFunction, withoutTrailingSlash } from '@ai-sdk/provider-utils';
import {
FetchFunction,
loadApiKey,
withoutTrailingSlash,
} from '@ai-sdk/provider-utils';
import { OpenAICompatibleChatLanguageModel } from './openai-compatible-chat-language-model';
import { OpenAICompatibleChatSettings } from './openai-compatible-chat-settings';
import { OpenAICompatibleCompletionLanguageModel } from './openai-compatible-completion-language-model';
Expand Down Expand Up @@ -44,23 +48,31 @@ export interface OpenAICompatibleProvider<

export interface OpenAICompatibleProviderSettings {
/**
Base URL for the API calls.
Base URL for the API calls.
*/
baseURL?: string;

/**
Custom headers to include in the requests.
API key for authenticating requests. If specified, adds an `Authorization`
header to request headers with the value `Bearer <apiKey>`. This will be added
before any headers potentially specified in the `headers` option.
*/
apiKey?: string;

/**
Optional custom headers to include in requests. These will be added to request headers
after any headers potentially added by use of the `apiKey` option.
*/
headers?: Record<string, string>;

/**
Custom fetch implementation. You can use it as a middleware to intercept requests,
or to provide a custom fetch implementation for e.g. testing.
Custom fetch implementation. You can use it as a middleware to intercept requests,
or to provide a custom fetch implementation for e.g. testing.
*/
fetch?: FetchFunction;

/**
Provider name.
Provider name.
*/
name?: string;
}
Expand Down Expand Up @@ -96,10 +108,15 @@ export function createOpenAICompatible<
fetch?: FetchFunction;
}

const getHeaders = () => ({
...(options.apiKey && { Authorization: `Bearer ${options.apiKey}` }),
...options.headers,
});

const getCommonModelConfig = (modelType: string): CommonModelConfig => ({
provider: `${providerName}.${modelType}`,
url: ({ path }) => `${baseURL}${path}`,
headers: () => options.headers ?? {},
headers: getHeaders,
fetch: options.fetch,
});

Expand Down

0 comments on commit 92ac806

Please sign in to comment.