Skip to content

Commit

Permalink
feat: add telemetry origin (#141)
Browse files Browse the repository at this point in the history
* feat: anonymised telemetry to track usage patterns

* add: PR suggestions

* feat: add telemetry origin field

* feat: fix POST definition

---------

Co-authored-by: Siddharth Sharma <[email protected]>
Co-authored-by: siddharth sharma <[email protected]>
  • Loading branch information
3 people authored May 3, 2024
1 parent eebdcc6 commit dab7454
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 30 deletions.
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ docker-compose up -d ui && docker-compose ps && docker-compose logs -f

Server
```bash
# navigate to <URL>/docs for API definitions
docker-compose up -d server && docker-compose ps && docker-compose logs -f
```

Expand Down Expand Up @@ -102,7 +103,10 @@ tts.synthesise(text="This is a demo of text to speech by MetaVoice-1B, an open-s
```bash
# You can use `--quantisation_mode int4` or `--quantisation_mode int8` for experimental faster inference. This will degrade the quality of the audio.
# Note: int8 is slower than bf16/fp16 for undebugged reasons. If you want fast, try int4 which is roughly 2x faster than bf16/fp16.

# navigate to <URL>/docs for API definitions
poetry run python serving.py

poetry run python app.py
```

Expand Down
4 changes: 2 additions & 2 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@
from fam.llm.utils import check_audio_file

#### setup model
TTS_MODEL = tyro.cli(TTS)
TTS_MODEL = tyro.cli(TTS, args=["--telemetry_origin", "webapp"])

#### setup interface
RADIO_CHOICES = ["Preset voices", "Upload target voice (atleast 30s)"]
MAX_CHARS = 220
PRESET_VOICES = {
# female
"Bria": "https://cdn.themetavoice.xyz/speakers%2Fbria.mp3",
"Bria": "https://cdn.themetavoice.xyz/speakers/bria.mp3",
# male
"Alex": "https://cdn.themetavoice.xyz/speakers/alex.mp3",
"Jacob": "https://cdn.themetavoice.xyz/speakers/jacob.wav",
Expand Down
4 changes: 4 additions & 0 deletions fam/llm/fast_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def __init__(
output_dir: str = "outputs",
quantisation_mode: Optional[Literal["int4", "int8"]] = None,
first_stage_path: Optional[str] = None,
telemetry_origin: Optional[str] = None,
):
"""
Initialise the TTS model.
Expand All @@ -60,6 +61,7 @@ def __init__(
- int4 for int4 weight-only quantisation,
- int8 for int8 weight-only quantisation.
first_stage_path: path to first-stage LLM checkpoint. If provided, this will override the one grabbed from Hugging Face via `model_name`.
telemetry_origin: A string identifier that specifies the origin of the telemetry data sent to PostHog.
"""

# NOTE: this needs to come first so that we don't change global state when we want to use
Expand Down Expand Up @@ -104,6 +106,7 @@ def __init__(
self._seed = seed
self._quantisation_mode = quantisation_mode
self._model_name = model_name
self._telemetry_origin = telemetry_origin

def synthesise(self, text: str, spk_ref_path: str, top_p=0.95, guidance_scale=3.0, temperature=1.0) -> str:
"""
Expand Down Expand Up @@ -183,6 +186,7 @@ def synthesise(self, text: str, spk_ref_path: str, top_p=0.95, guidance_scale=3.
"seed": self._seed,
"first_stage_ckpt": self._first_stage_ckpt,
"gpu": torch.cuda.get_device_name(0),
"telemetry_origin": self._telemetry_origin,
},
)
)
Expand Down
59 changes: 31 additions & 28 deletions serving.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import json
import logging
import shlex
import subprocess
Expand All @@ -12,7 +11,7 @@
import tyro
import uvicorn
from attr import dataclass
from fastapi import Request
from fastapi import File, Form, HTTPException, UploadFile, status
from fastapi.responses import Response

from fam.llm.fast_inference import TTS
Expand Down Expand Up @@ -50,55 +49,55 @@ class _GlobalState:
GlobalState = _GlobalState()


@dataclass(frozen=True)
class TTSRequest:
text: str
speaker_ref_path: Optional[str] = None
guidance: float = 3.0
top_p: float = 0.95
top_k: Optional[int] = None


@app.get("/health")
async def health_check():
return {"status": "ok"}


@app.post("/tts", response_class=Response)
async def text_to_speech(req: Request):
audiodata = await req.body()
payload = None
async def text_to_speech(
text: str = Form(...),
speaker_ref_path: Optional[str] = Form(None),
guidance: float = Form(3.0),
top_p: float = Form(0.95),
audiodata: Optional[UploadFile] = File(None),
):
# Ensure at least one of speaker_ref_path or audiodata is provided
if not audiodata and not speaker_ref_path:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Either an audio file or a speaker reference path must be provided.",
)

wav_out_path = None

try:
headers = req.headers
payload = headers["X-Payload"]
payload = json.loads(payload)
tts_req = TTSRequest(**payload)
with tempfile.NamedTemporaryFile(suffix=".wav") as wav_tmp:
if tts_req.speaker_ref_path is None:
if speaker_ref_path is None:
wav_path = _convert_audiodata_to_wav_path(audiodata, wav_tmp)
check_audio_file(wav_path)
else:
# TODO: fix
wav_path = tts_req.speaker_ref_path
wav_path = speaker_ref_path

if wav_path is None:
warnings.warn("Running without speaker reference")
assert tts_req.guidance is None
assert guidance is None

wav_out_path = GlobalState.tts.synthesise(
text=tts_req.text,
text=text,
spk_ref_path=wav_path,
top_p=tts_req.top_p,
guidance_scale=tts_req.guidance,
top_p=top_p,
guidance_scale=guidance,
)

with open(wav_out_path, "rb") as f:
return Response(content=f.read(), media_type="audio/wav")
except Exception as e:
# traceback_str = "".join(traceback.format_tb(e.__traceback__))
logger.exception(f"Error processing request {payload}")
logger.exception(
f"Error processing request. text: {text}, speaker_ref_path: {speaker_ref_path}, guidance: {guidance}, top_p: {top_p}"
)
return Response(
content="Something went wrong. Please try again in a few mins or contact us on Discord",
status_code=500,
Expand All @@ -108,9 +107,9 @@ async def text_to_speech(req: Request):
Path(wav_out_path).unlink(missing_ok=True)


def _convert_audiodata_to_wav_path(audiodata, wav_tmp):
def _convert_audiodata_to_wav_path(audiodata: UploadFile, wav_tmp):
with tempfile.NamedTemporaryFile() as unknown_format_tmp:
if unknown_format_tmp.write(audiodata) == 0:
if unknown_format_tmp.write(audiodata.read()) == 0:
return None
unknown_format_tmp.flush()

Expand All @@ -129,7 +128,11 @@ def _convert_audiodata_to_wav_path(audiodata, wav_tmp):
logging.root.setLevel(logging.INFO)

GlobalState.config = tyro.cli(ServingConfig)
GlobalState.tts = TTS(seed=GlobalState.config.seed, quantisation_mode=GlobalState.config.quantisation_mode)
GlobalState.tts = TTS(
seed=GlobalState.config.seed,
quantisation_mode=GlobalState.config.quantisation_mode,
telemetry_origin="api_server",
)

app.add_middleware(
fastapi.middleware.cors.CORSMiddleware,
Expand Down

0 comments on commit dab7454

Please sign in to comment.