Merge pull request #10 from jeasonstudio/feat-stream-text-support-abort-signal

feat: stream text/object support abort signal
jeasonstudio authored Jul 5, 2024
2 parents f813465 + c999022 commit d9c3ce3
Showing 10 changed files with 138 additions and 44 deletions.
5 changes: 5 additions & 0 deletions .changeset/polite-forks-whisper.md
@@ -0,0 +1,5 @@
---
"chrome-ai": patch
---

feat: stream text/object support abort signal
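For context, a minimal consumer-side sketch of what this patch enables. It assumes the published `chrome-ai` entry point together with the Vercel `ai` SDK's `abortSignal` call option (which the SDK forwards to the provider's `doStream`); the prompt and the two-second timeout are made up for illustration, and depending on the SDK version the abort may also surface as an error to handle.

```ts
import { streamText } from 'ai';
import { chromeai } from 'chrome-ai';

const controller = new AbortController();

// Cancel the generation after two seconds.
setTimeout(() => controller.abort(), 2000);

const result = await streamText({
  model: chromeai(),
  prompt: 'Write a long story about a robot.',
  abortSignal: controller.signal, // handed down to StreamAI (see src/stream-ai.ts below)
});

for await (const textPart of result.textStream) {
  console.log(textPart); // chunks stop arriving once the signal fires
}
```

The same option applies to `streamObject`, since both paths go through the provider's `doStream` and the shared `StreamAI` transform.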
11 changes: 11 additions & 0 deletions src/chromeai.test.ts
@@ -0,0 +1,11 @@
import { describe, it, expect } from 'vitest';
import { chromeai } from './chromeai';

describe('chromeai', () => {
  it('should correctly create instance', async () => {
    expect(chromeai().modelId).toBe('generic');
    expect(chromeai('text').modelId).toBe('text');
    expect(chromeai('embedding').modelId).toBe('embedding');
    expect(chromeai.embedding().modelId).toBe('embedding');
  });
});
40 changes: 40 additions & 0 deletions src/chromeai.ts
@@ -0,0 +1,40 @@
import {
  ChromeAIEmbeddingModel,
  ChromeAIEmbeddingModelSettings,
} from './embedding-model';
import {
  ChromeAIChatLanguageModel,
  ChromeAIChatModelId,
  ChromeAIChatSettings,
} from './language-model';
import createDebug from 'debug';

const debug = createDebug('chromeai');

/**
 * Create a new ChromeAI model/embedding instance.
 * @param modelId 'generic' | 'text' | 'embedding'
 * @param settings Options for the model
 */
export function chromeai(
  modelId?: ChromeAIChatModelId,
  settings?: ChromeAIChatSettings
): ChromeAIChatLanguageModel;
export function chromeai(
  modelId?: 'embedding',
  settings?: ChromeAIEmbeddingModelSettings
): ChromeAIEmbeddingModel;
export function chromeai(modelId: string = 'generic', settings: any = {}) {
  debug('create instance', modelId, settings);
  if (modelId === 'embedding') {
    return new ChromeAIEmbeddingModel(settings);
  }
  return new ChromeAIChatLanguageModel(
    modelId as ChromeAIChatModelId,
    settings
  );
}

/** @deprecated use `chromeai('embedding'[, options])` */
chromeai.embedding = (settings: ChromeAIEmbeddingModelSettings = {}) =>
  new ChromeAIEmbeddingModel(settings);
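A short usage sketch of the new unified factory (non-authoritative; it assumes the package is consumed as `chrome-ai` with the `ai` SDK helpers exercised in the tests, and the prompt/value strings are placeholders):

```ts
import { generateText, embed } from 'ai';
import { chromeai } from 'chrome-ai';

// The chat overload returns a ChromeAIChatLanguageModel.
const { text } = await generateText({
  model: chromeai('text', { temperature: 1, topK: 10 }),
  prompt: 'Say hello',
});

// The 'embedding' overload returns a ChromeAIEmbeddingModel,
// replacing the now-deprecated chromeai.embedding() helper.
const { embedding } = await embed({
  model: chromeai('embedding'),
  value: 'hello world',
});
```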
8 changes: 4 additions & 4 deletions src/embedding-model.test.ts
@@ -1,5 +1,5 @@
import { describe, it, expect, vi } from 'vitest';
import { ChromeAIEmbeddingModel, chromeEmbedding } from './embedding-model';
import { ChromeAIEmbeddingModel } from './embedding-model';
import { embed } from 'ai';

vi.mock('@mediapipe/tasks-text', async () => ({
@@ -23,10 +23,10 @@ vi.mock('@mediapipe/tasks-text', async () => ({
describe('embedding-model', () => {
it('should instantiation anyways', async () => {
expect(new ChromeAIEmbeddingModel()).toBeInstanceOf(ChromeAIEmbeddingModel);
expect(chromeEmbedding()).toBeInstanceOf(ChromeAIEmbeddingModel);
expect(new ChromeAIEmbeddingModel()).toBeInstanceOf(ChromeAIEmbeddingModel);
});
it('should embed', async () => {
const model = chromeEmbedding();
const model = new ChromeAIEmbeddingModel();
expect(
await embed({
model,
@@ -45,7 +45,7 @@ describe('embedding-model', () => {
it('should embed result empty', async () => {
expect(
await embed({
model: chromeEmbedding({ l2Normalize: true }),
model: new ChromeAIEmbeddingModel({ l2Normalize: true }),
value: 'undefined',
})
).toMatchObject({ embedding: [] });
21 changes: 9 additions & 12 deletions src/embedding-model.ts
@@ -36,10 +36,13 @@ export interface ChromeAIEmbeddingModelSettings {
delegate?: 'CPU' | 'GPU';
}

// See more:
// - https://github.com/google-ai-edge/mediapipe
// - https://ai.google.dev/edge/mediapipe/solutions/text/text_embedder/web_js
export class ChromeAIEmbeddingModel implements EmbeddingModelV1<string> {
readonly specificationVersion = 'v1';
readonly provider = 'google-mediapipe';
readonly modelId: string = 'mediapipe';
readonly modelId: string = 'embedding';
readonly supportsParallelCalls = true;
readonly maxEmbeddingsPerCall = undefined;

@@ -80,18 +83,12 @@ export class ChromeAIEmbeddingModel implements EmbeddingModelV1<string> {
rawResponse?: Record<PropertyKey, any>;
}> => {
// if (options.abortSignal) console.warn('abortSignal is not supported');

const embedder = await this.getTextEmbedder();
const embeddings = await Promise.all(
options.values.map((text) => {
const embedderResult = embedder.embed(text);
const [embedding] = embedderResult.embeddings;
return embedding?.floatEmbedding ?? [];
})
);
const embeddings = options.values.map((text) => {
const embedderResult = embedder.embed(text);
const [embedding] = embedderResult.embeddings;
return embedding?.floatEmbedding ?? [];
});
return { embeddings };
};
}

export const chromeEmbedding = (options?: ChromeAIEmbeddingModelSettings) =>
new ChromeAIEmbeddingModel(options);
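Since the `chromeEmbedding` helper is dropped here, existing callers have two equivalent replacements; a brief migration sketch (import path assumed to be the package root):

```ts
import { ChromeAIEmbeddingModel, chromeai } from 'chrome-ai';

// Before: const model = chromeEmbedding({ l2Normalize: true });

// After: construct the class directly...
const direct = new ChromeAIEmbeddingModel({ l2Normalize: true });

// ...or go through the unified factory.
const viaFactory = chromeai('embedding', { l2Normalize: true });
```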
1 change: 1 addition & 0 deletions src/index.ts
@@ -1,2 +1,3 @@
export * from './language-model';
export * from './embedding-model';
export * from './chromeai';
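With this re-export, the factory and both model classes are reachable from the package root; a tiny sketch (package name taken from the changeset above):

```ts
import {
  chromeai,
  ChromeAIChatLanguageModel,
  ChromeAIEmbeddingModel,
} from 'chrome-ai';

// chromeai() is a thin factory over the two classes below.
const chat = chromeai();
const embedder = new ChromeAIEmbeddingModel();

console.log(chat.modelId);     // 'generic'
console.log(embedder.modelId); // 'embedding'
```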
52 changes: 36 additions & 16 deletions src/language-model.test.ts
@@ -1,5 +1,5 @@
import { describe, it, expect, vi, afterEach } from 'vitest';
import { chromeai, ChromeAIChatLanguageModel } from './index';
import { ChromeAIChatLanguageModel } from './index';
import { generateText, streamText, generateObject, streamObject } from 'ai';
import {
LoadSettingError,
@@ -14,17 +14,22 @@ describe('language-model', () => {
});

it('should instantiation anyways', () => {
expect(chromeai()).toBeInstanceOf(ChromeAIChatLanguageModel);
expect(chromeai().modelId).toBe('generic');
expect(chromeai('text').modelId).toBe('text');
expect(new ChromeAIChatLanguageModel('generic')).toBeInstanceOf(
ChromeAIChatLanguageModel
);
expect(new ChromeAIChatLanguageModel('text').modelId).toBe('text');
expect(
chromeai('text', { temperature: 1, topK: 10 }).options
new ChromeAIChatLanguageModel('text', { temperature: 1, topK: 10 })
.options
).toMatchObject({ temperature: 1, topK: 10 });
});

it('should throw when not support', async () => {
await expect(() =>
generateText({ model: chromeai(), prompt: 'empty' })
generateText({
model: new ChromeAIChatLanguageModel('generic'),
prompt: 'empty',
})
).rejects.toThrowError(LoadSettingError);

const cannotCreateSession = vi.fn(async () => 'no');
@@ -34,12 +39,18 @@ describe('language-model', () => {
});

await expect(() =>
generateText({ model: chromeai('text'), prompt: 'empty' })
generateText({
model: new ChromeAIChatLanguageModel('text'),
prompt: 'empty',
})
).rejects.toThrowError(LoadSettingError);
expect(cannotCreateSession).toHaveBeenCalledTimes(1);

await expect(() =>
generateText({ model: chromeai('generic'), prompt: 'empty' })
generateText({
model: new ChromeAIChatLanguageModel('generic'),
prompt: 'empty',
})
).rejects.toThrowError(LoadSettingError);
expect(cannotCreateSession).toHaveBeenCalledTimes(2);
});
@@ -58,17 +69,23 @@ describe('language-model', () => {
createTextSession: createSession,
});

await generateText({ model: chromeai('text'), prompt: 'test' });
await generateText({
model: new ChromeAIChatLanguageModel('text'),
prompt: 'test',
});
expect(getOptions).toHaveBeenCalledTimes(1);

const result = await generateText({ model: chromeai(), prompt: 'test' });
const result = await generateText({
model: new ChromeAIChatLanguageModel('generic'),
prompt: 'test',
});
expect(result).toMatchObject({
finishReason: 'stop',
text: 'test',
});

const resultForMessages = await generateText({
model: chromeai(),
model: new ChromeAIChatLanguageModel('generic'),
messages: [
{ role: 'user', content: 'test' },
{ role: 'assistant', content: 'assistant' },
@@ -95,7 +112,10 @@ describe('language-model', () => {
createGenericSession: vi.fn(async () => ({ promptStreaming })),
});

const result = await streamText({ model: chromeai(), prompt: 'test' });
const result = await streamText({
model: new ChromeAIChatLanguageModel('generic'),
prompt: 'test',
});
for await (const textPart of result.textStream) {
expect(textPart).toBe('test');
}
@@ -110,7 +130,7 @@ describe('language-model', () => {
});

const { object } = await generateObject({
model: chromeai(),
model: new ChromeAIChatLanguageModel('generic'),
schema: z.object({
hello: z.string(),
}),
@@ -129,7 +149,7 @@ describe('language-model', () => {
});
await expect(() =>
generateText({
model: chromeai(),
model: new ChromeAIChatLanguageModel('generic'),
messages: [
{
role: 'tool',
@@ -161,7 +181,7 @@ describe('language-model', () => {

await expect(() =>
generateObject({
model: chromeai(),
model: new ChromeAIChatLanguageModel('generic'),
mode: 'grammar',
schema: z.object({}),
prompt: 'test',
@@ -170,7 +190,7 @@ describe('language-model', () => {

await expect(() =>
streamObject({
model: chromeai(),
model: new ChromeAIChatLanguageModel('generic'),
mode: 'grammar',
schema: z.object({}),
prompt: 'test',
14 changes: 5 additions & 9 deletions src/language-model.ts
@@ -17,7 +17,10 @@ import {
import { ChromeAISession, ChromeAISessionOptions } from './global';
import createDebug from 'debug';
import { StreamAI } from './stream-ai';
import { chromeEmbedding } from './embedding-model';
import {
ChromeAIEmbeddingModel,
ChromeAIEmbeddingModelSettings,
} from './embedding-model';

const debug = createDebug('chromeai');

@@ -224,7 +227,7 @@ export class ChromeAIChatLanguageModel implements LanguageModelV1 {
const session = await this.getSession();
const message = this.formatMessages(options);
const promptStream = session.promptStreaming(message);
const transformStream = new StreamAI();
const transformStream = new StreamAI(options.abortSignal);
const stream = promptStream.pipeThrough(transformStream);

return {
@@ -233,10 +236,3 @@
};
};
}

export const chromeai = (
modelId: ChromeAIChatModelId = 'generic',
settings: ChromeAIChatSettings = {}
) => new ChromeAIChatLanguageModel(modelId, settings);

chromeai.embedding = chromeEmbedding;
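The substantive change in this file is that `doStream` now passes `options.abortSignal` into `StreamAI`. A sketch of the low-level call path is below; the call-options shape follows the `LanguageModelV1` spec as I understand it for this SDK generation, so treat the exact field names as an assumption, and note it needs a browser with Chrome's built-in AI available.

```ts
import { ChromeAIChatLanguageModel } from 'chrome-ai';

const model = new ChromeAIChatLanguageModel('generic');
const controller = new AbortController();

// What streamText/streamObject do under the hood: the abortSignal in the
// call options becomes the constructor argument of StreamAI.
const { stream } = await model.doStream({
  inputFormat: 'prompt',
  mode: { type: 'regular' },
  prompt: [{ role: 'user', content: [{ type: 'text', text: 'hi' }] }],
  abortSignal: controller.signal,
});

const reader = stream.getReader();
console.log(await reader.read()); // { value: { type: 'text-delta', ... }, done: false }

controller.abort();               // StreamAI terminates; subsequent reads report done
```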
20 changes: 19 additions & 1 deletion src/stream-ai.test.ts
@@ -1,4 +1,4 @@
import { describe, it, expect, vi, afterEach } from 'vitest';
import { describe, it, expect } from 'vitest';
import { StreamAI } from './stream-ai';

describe('stream-ai', () => {
@@ -23,4 +23,22 @@ describe('stream-ai', () => {
value: { type: 'finish' },
});
});

it('should abort when signal', async () => {
const controller = new AbortController();
const transformStream = new StreamAI(controller.signal);

const writer = transformStream.writable.getWriter();
const reader = transformStream.readable.getReader();

writer.write('hello');

expect(await reader.read()).toMatchObject({
value: { type: 'text-delta', textDelta: 'hello' },
done: false,
});

controller.abort();
expect(await reader.read()).toMatchObject({ done: true });
});
});
10 changes: 8 additions & 2 deletions src/stream-ai.ts
@@ -7,11 +7,17 @@ export class StreamAI extends TransformStream<
string,
LanguageModelV1StreamPart
> {
public constructor() {
public constructor(abortSignal?: AbortSignal) {
let textTemp = '';
super({
start: () => {
start: (controller) => {
textTemp = '';
if (!abortSignal) return;
abortSignal.addEventListener('abort', () => {
debug('streamText terminate by abortSignal');
controller.terminate();
textTemp = '';
});
},
transform: (chunk, controller) => {
const textDelta = chunk.replace(textTemp, '');
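A standalone sketch of the new abort path, mirroring the unit test above (driving the transform directly instead of through the `ai` SDK; the relative import assumes in-repo usage):

```ts
import { StreamAI } from './stream-ai';

const controller = new AbortController();
const transformStream = new StreamAI(controller.signal);

const writer = transformStream.writable.getWriter();
const reader = transformStream.readable.getReader();

writer.write('hello');
// The transform turns the raw chunk into a text-delta stream part.
console.log(await reader.read());
// -> { value: { type: 'text-delta', textDelta: 'hello' }, done: false }

controller.abort();
// The 'abort' listener registered in start() calls controller.terminate(),
// so the readable side closes and the next read reports done.
console.log(await reader.read()); // -> { done: true }
```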
