Skip to content

Commit

Permalink
feat(streaming): add PlayDialog engine/model support (#43)
Browse files Browse the repository at this point in the history
  • Loading branch information
acdcjunior authored Dec 11, 2024
1 parent dd0465b commit 8018159
Show file tree
Hide file tree
Showing 13 changed files with 368 additions and 156 deletions.
1 change: 1 addition & 0 deletions packages/playht/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
test-output-*.mp3
3 changes: 2 additions & 1 deletion packages/playht/package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "playht",
"version": "0.13.0",
"version": "0.14.0",
"description": "NodeJS SDK for PlayHT generative AI text-to-speech APIs",
"files": [
"dist/**/*",
Expand All @@ -22,6 +22,7 @@
"verify": "yarn check && yarn test",
"check": "yarn build:protobufs && tsc -p tsconfig.json --noEmit && prettier --check . && eslint --ext .ts ./src",
"release": "yarn && yarn verify && yarn build && cp ../../README.md . && npm publish || true && rm README.md",
"release-alpha": "yarn && yarn verify && yarn build && cp ../../README.md . && npm publish --tag=alpha || true && rm README.md",
"postpublish": "PACKAGE_VERSION=$(cat package.json | grep \\\"version\\\" | head -1 | awk -F: '{ print $2 }' | sed 's/[\",]//g' | tr -d '[[:space:]]') && git tag v$PACKAGE_VERSION && git push --tags"
},
"devDependencies": {
Expand Down
92 changes: 92 additions & 0 deletions packages/playht/src/__tests__/e2eStreaming.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
import { buffer } from 'node:stream/consumers';
import fs from 'node:fs';
import { describe, expect, it } from '@jest/globals';
import * as PlayHT from '../index';
import { E2E_CONFIG } from './e2eTestConfig';

describe('E2E Streaming', () => {
describe('Play3.0-mini', () => {
it('streams from text', async () => {
PlayHT.init({
userId: E2E_CONFIG.USER_ID,
apiKey: E2E_CONFIG.API_KEY,
});

const streamFromText = await PlayHT.stream('Hello from SDK test.', {
voiceEngine: 'Play3.0-mini',
// @ts-expect-error emotion is not part of the Play3.0-mini contract
emotion: 'female_surprised',
outputFormat: 'mp3',
});

const audioBuffer = await buffer(streamFromText);
fs.writeFileSync('test-output-Play3.0-mini.mp3', audioBuffer); // for debugging

expect(audioBuffer.length).toBeGreaterThan(30_000); // errors would result in smaller payloads
expect(audioBuffer.toString('ascii')).toContain('ID3');
});
});

describe('PlayDialog', () => {
it('streams from text', async () => {
PlayHT.init({
userId: E2E_CONFIG.USER_ID,
apiKey: E2E_CONFIG.API_KEY,
});

const streamFromText = await PlayHT.stream('Host 1: Is this the SDK?\nHost 2: Yes, it is.', {
voiceEngine: 'PlayDialog',
outputFormat: 'mp3',
temperature: 1.2,
quality: 'high',
voiceId2: 's3://voice-cloning-zero-shot/775ae416-49bb-4fb6-bd45-740f205d20a1/jennifersaad/manifest.json',
turnPrefix: 'Host 1:',
turnPrefix2: 'Host 2:',
language: 'english',

// @ts-expect-error emotion and language are not part of the PlayDialog contract
emotion: 'female_surprised',
styleGuidance: 16,
});

const audioBuffer = await buffer(streamFromText);
fs.writeFileSync('test-output-PlayDialog.mp3', audioBuffer); // for debugging

expect(audioBuffer.length).toBeGreaterThan(30_000); // errors would result in smaller payloads
expect(audioBuffer.toString('ascii')).toContain('ID3');
}, 120_000);
});

describe('PlayDialogMultilingual', () => {
it('streams from text', async () => {
PlayHT.init({
userId: E2E_CONFIG.USER_ID,
apiKey: E2E_CONFIG.API_KEY,
});

const streamFromText = await PlayHT.stream(
'Host 1: Estamos todos prontos para fazer o que for necessário aqui. Host 2: É impossível esquecer tudo que vivemos.',
{
voiceEngine: 'PlayDialog',
outputFormat: 'mp3',
temperature: 1.2,
quality: 'high',
voiceId2: 's3://voice-cloning-zero-shot/775ae416-49bb-4fb6-bd45-740f205d20a1/jennifersaad/manifest.json',
turnPrefix: 'Host 1:',
turnPrefix2: 'Host 2:',
language: 'portuguese',

// @ts-expect-error emotion and language are not part of the PlayDialog contract
emotion: 'female_surprised',
styleGuidance: 16,
},
);

const audioBuffer = await buffer(streamFromText);
fs.writeFileSync('test-output-PlayDialogMultilingual.mp3', audioBuffer); // for debugging

expect(audioBuffer.length).toBeGreaterThan(30_000); // errors would result in smaller payloads
expect(audioBuffer.toString('ascii')).toContain('ID3');
}, 120_000);
});
});
11 changes: 6 additions & 5 deletions packages/playht/src/api/apiCommon.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import type {
PlayHT20OutputStreamFormat,
Play30EngineStreamOptions,
OutputFormat,
PlayDialogEngineStreamOptions,
} from '..';
import { PassThrough, Readable, Writable } from 'node:stream';
import { APISettingsStore } from './APISettingsStore';
Expand All @@ -18,7 +19,7 @@ import { generateV2Speech } from './generateV2Speech';
import { generateV2Stream } from './generateV2Stream';
import { textStreamToSentences } from './textStreamToSentences';
import { generateGRpcStream } from './generateGRpcStream';
import { generateV3Stream } from './internal/tts/v3/generateV3Stream';
import { generateAuthBasedStream } from './internal/tts/v3/generateAuthBasedStream';
import { PlayRequestConfig } from './internal/config/PlayRequestConfig';

export type V1ApiOptions = {
Expand All @@ -43,8 +44,7 @@ export type V2ApiOptions = {
textGuidance?: number;
};

export type V3ApiOptions = Pick<Play30EngineStreamOptions, 'language' | 'voiceEngine'> &
Omit<V2ApiOptions, 'voiceEngine' | 'emotion'>;
export type AuthBasedEngineOptions = Play30EngineStreamOptions | PlayDialogEngineStreamOptions;

type Preset = 'real-time' | 'balanced' | 'low-latency' | 'high-quality';

Expand Down Expand Up @@ -104,8 +104,9 @@ export async function internalGenerateStreamFromString(
const v2Options = toV2Options(options, true);
return await generateGRpcStream(input, options.voiceId, v2Options);
}
case 'Play3.0-mini': {
return await generateV3Stream(input, options.voiceId, options, reqConfig);
case 'Play3.0-mini':
case 'PlayDialog': {
return await generateAuthBasedStream(input, options.voiceId, options, reqConfig);
}
}
}
Expand Down
16 changes: 15 additions & 1 deletion packages/playht/src/api/internal/tts/v3/V3InternalSettings.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,25 @@
/**
* "Public" because these are the engines the users can choose from.
*/
export type PublicAuthBasedEngine = 'Play3.0-mini' | 'PlayDialog';

/**
* "Internal" because these are the engines we use internally to determine the inference address (the HTTP endpoint).
*/
export type InternalAuthBasedEngine = PublicAuthBasedEngine | 'PlayDialogMultilingual';

export type V3InternalSettings = {
// how much time before expiration should we refresh the coordinates
coordinatesExpirationAdvanceRefreshTimeMs?: number;
// refresh no more frequently than this
coordinatesExpirationMinimalFrequencyMs?: number;
// number of attempts when calling API to get new coordinates
coordinatesGetApiCallMaxRetries?: number;
customInferenceCoordinatesGenerator?: (userId: string, apiKey: string) => Promise<InferenceCoordinatesEntry>;
customInferenceCoordinatesGenerator?: (
engine: InternalAuthBasedEngine,
userId: string,
apiKey: string,
) => Promise<InferenceCoordinatesEntry>;
};

export type InferenceCoordinatesEntry = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,29 @@ import axios, { AxiosRequestConfig } from 'axios';
import { keepAliveHttpsAgent } from '../../http';
import { PlayRequestConfig } from '../../config/PlayRequestConfig';
import { createOrGetInferenceAddress } from './createOrGetInferenceAddress';
import { InternalAuthBasedEngine, PublicAuthBasedEngine } from './V3InternalSettings';

export const backgroundWarmUpAuthBasedEngine = (reqConfigSettings: PlayRequestConfig['settings']) => {
warmUp(reqConfigSettings).catch((error: any) => {
// eslint-disable-next-line no-process-env
console.log(`[PlayHT SDK] Error while warming up SDK: ${error.message}`, process.env.DEBUG ? error : '');
});
export const backgroundWarmUpAuthBasedEngine = (
selectedEngine: PublicAuthBasedEngine,
reqConfigSettings: PlayRequestConfig['settings'],
) => {
const engines =
selectedEngine === 'Play3.0-mini'
? (['Play3.0-mini'] as const)
: (['PlayDialog', 'PlayDialogMultilingual'] as const);
for (const engine of engines) {
warmUp(engine, reqConfigSettings).catch((error: any) => {
console.log(
`[PlayHT SDK] Error while warming up SDK (${engine}): ${error.message}`,
// eslint-disable-next-line no-process-env
process.env.DEBUG ? error : '',
);
});
}
};

const warmUp = async (reqConfigSettings: PlayRequestConfig['settings']) => {
const inferenceAddress = await createOrGetInferenceAddress(reqConfigSettings);
const warmUp = async (engine: InternalAuthBasedEngine, reqConfigSettings: PlayRequestConfig['settings']) => {
const inferenceAddress = await createOrGetInferenceAddress(engine, reqConfigSettings);
const streamOptions: AxiosRequestConfig = {
method: 'OPTIONS',
url: inferenceAddress,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,21 +1,25 @@
import { describe, expect } from '@jest/globals';
import { beforeEach, describe, expect } from '@jest/globals';
import { createOrGetInferenceAddress } from './createOrGetInferenceAddress';
import { InternalAuthBasedEngine } from './V3InternalSettings';

async function sleep(timeout: number) {
await new Promise((resolve) => setTimeout(resolve, timeout));
}

describe('createOrGetInferenceAddress', () => {
let callSequenceNumber = 0;
let callSequenceNumber: number;
beforeEach(() => {
callSequenceNumber = 0;
});
const reqConfigSettings = (userId: string) => ({
userId,
apiKey: 'test',
apiKey: 'test-api-key',
experimental: {
v3: {
customInferenceCoordinatesGenerator: async () => {
customInferenceCoordinatesGenerator: async (_: InternalAuthBasedEngine, u: string) => {
await sleep(10); // simulate a delay
return {
inferenceAddress: `call ${userId} #${++callSequenceNumber}`,
inferenceAddress: `call ${u} #${++callSequenceNumber}`,
expiresAtMs: Date.now() + 1_000_000,
};
},
Expand All @@ -29,7 +33,7 @@ describe('createOrGetInferenceAddress', () => {
it('serializes concurrent calls for the same user', async () => {
const numberOfTestCalls = 15;
const calls = Array.from({ length: numberOfTestCalls }, () =>
createOrGetInferenceAddress(reqConfigSettings('test-user')),
createOrGetInferenceAddress('Play3.0-mini', reqConfigSettings('test-user')),
);

// Expect all calls to return 'call #1', not 'call #1', 'call #2', 'call #3', etc.
Expand All @@ -39,19 +43,19 @@ describe('createOrGetInferenceAddress', () => {
it('doesnt serialize calls for different users', async () => {
const numberOfTestCalls = 3;
const callsOne = Array.from({ length: numberOfTestCalls }, (_, i) =>
createOrGetInferenceAddress(reqConfigSettings(`test-user#${i}`)),
createOrGetInferenceAddress('Play3.0-mini', reqConfigSettings(`test-user#${i}`)),
);
const callsTwo = Array.from({ length: numberOfTestCalls }, (_, i) =>
createOrGetInferenceAddress(reqConfigSettings(`test-user#${i}`)),
createOrGetInferenceAddress('Play3.0-mini', reqConfigSettings(`test-user#${i}`)),
);

expect(await Promise.all([...callsOne, ...callsTwo])).toEqual([
'call test-user#0 #2',
'call test-user#1 #3',
'call test-user#2 #4',
'call test-user#0 #2',
'call test-user#1 #3',
'call test-user#2 #4',
'call test-user#0 #1',
'call test-user#1 #2',
'call test-user#2 #3',
'call test-user#0 #1',
'call test-user#1 #2',
'call test-user#2 #3',
]);
});
});
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,24 @@ import { keepAliveHttpsAgent } from '../../http';
import { PlayRequestConfig } from '../../config/PlayRequestConfig';
import { APISettingsStore } from '../../../APISettingsStore';
import { UserId } from '../../types';
import { InferenceCoordinatesEntry } from './V3InternalSettings';
import { InternalAuthBasedEngine, InferenceCoordinatesEntry, V3InternalSettings } from './V3InternalSettings';
import { V3_DEFAULT_SETTINGS } from './V3DefaultSettings';

const inferenceCoordinatesStore: Record<UserId, InferenceCoordinatesEntry> = {};
const inferenceCoordinatesStores: Record<InternalAuthBasedEngine, Record<UserId, InferenceCoordinatesEntry>> = {
'Play3.0-mini': {},
PlayDialog: {},
PlayDialogMultilingual: {},
};

// By default, the inference coordinates generator will call the Play API to get the inference coordinates.
const defaultInferenceCoordinatesGenerator = async (
userId: string,
apiKey: string,
const defaultInferenceCoordinatesGenerator: V3InternalSettings['customInferenceCoordinatesGenerator'] = async (
engine,
userId,
apiKey,
): Promise<InferenceCoordinatesEntry> => {
const data = await axios
.post(
'https://api.play.ht/api/v3/auth',
'https://api.play.ht/api/v3/auth?dialog',
{},
{
headers: {
Expand All @@ -28,19 +33,23 @@ const defaultInferenceCoordinatesGenerator = async (
)
.then(
(response) =>
response.data as {
inference_address: string;
response.data as Record<InternalAuthBasedEngine, { http_streaming_url: string; websocket_url: string }> & {
expires_at_ms: number;
},
)
.catch((error: any) => convertError(error));
const httpStreamingUrl = data[engine]?.http_streaming_url;
if (!httpStreamingUrl) {
return convertError(new Error(`Engine ${engine} not found in AUTH response`));
}
return {
inferenceAddress: data.inference_address,
inferenceAddress: httpStreamingUrl,
expiresAtMs: data.expires_at_ms,
};
};

const createInferenceCoordinates = async (
voiceEngine: InternalAuthBasedEngine,
reqConfigSettings?: PlayRequestConfig['settings'],
attemptNo = 0,
): Promise<InferenceCoordinatesEntry> => {
Expand All @@ -64,13 +73,13 @@ const createInferenceCoordinates = async (
V3_DEFAULT_SETTINGS.coordinatesGetApiCallMaxRetries;

try {
const newInferenceCoordinatesEntry = await inferenceCoordinatesGenerator(userId, apiKey);
const newInferenceCoordinatesEntry = await inferenceCoordinatesGenerator(voiceEngine, userId, apiKey);
const automaticRefreshDelay = Math.max(
coordinatesExpirationMinimalFrequencyMs,
newInferenceCoordinatesEntry.expiresAtMs - Date.now() - coordinatesExpirationAdvanceRefreshTimeMs,
);
setTimeout(() => createInferenceCoordinates(reqConfigSettings), automaticRefreshDelay).unref();
inferenceCoordinatesStore[userId] = newInferenceCoordinatesEntry;
setTimeout(() => createInferenceCoordinates(voiceEngine, reqConfigSettings), automaticRefreshDelay).unref();
inferenceCoordinatesStores[voiceEngine][userId] = newInferenceCoordinatesEntry;
return newInferenceCoordinatesEntry;
} catch (e) {
if (attemptNo >= coordinatesGetApiCallMaxRetries) {
Expand All @@ -79,7 +88,7 @@ const createInferenceCoordinates = async (
return new Promise((resolve) => {
setTimeout(
() => {
resolve(createInferenceCoordinates(reqConfigSettings, attemptNo + 1));
resolve(createInferenceCoordinates(voiceEngine, reqConfigSettings, attemptNo + 1));
},
500 * (attemptNo + 1),
).unref();
Expand All @@ -90,15 +99,16 @@ const createInferenceCoordinates = async (
const inferenceCoordinatesCreationPromise: Record<UserId, Promise<InferenceCoordinatesEntry>> = {};

export const createOrGetInferenceAddress = async (
voiceEngine: InternalAuthBasedEngine,
reqConfigSettings?: PlayRequestConfig['settings'],
): Promise<string> => {
const userId = (reqConfigSettings?.userId ?? APISettingsStore.getSettings().userId) as UserId;
const inferenceCoordinatesEntry = inferenceCoordinatesStore[userId];
const inferenceCoordinatesEntry = inferenceCoordinatesStores[voiceEngine][userId];
if (inferenceCoordinatesEntry && inferenceCoordinatesEntry.expiresAtMs >= Date.now() - 5_000) {
return inferenceCoordinatesEntry.inferenceAddress;
} else {
if (!(userId in inferenceCoordinatesCreationPromise)) {
inferenceCoordinatesCreationPromise[userId] = createInferenceCoordinates(reqConfigSettings);
inferenceCoordinatesCreationPromise[userId] = createInferenceCoordinates(voiceEngine, reqConfigSettings);
}
const newInferenceCoordinatesEntry = (await inferenceCoordinatesCreationPromise[userId])!;
delete inferenceCoordinatesCreationPromise[userId];
Expand Down
Loading

0 comments on commit 8018159

Please sign in to comment.