Skip to content

Commit

Permalink
fix(tokenizer): apply left padding
Browse files Browse the repository at this point in the history
  • Loading branch information
fumiama committed Jul 18, 2024
1 parent 85f0497 commit e69ffac
Show file tree
Hide file tree
Showing 5 changed files with 100 additions and 28 deletions.
49 changes: 45 additions & 4 deletions ChatTTS/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,14 @@
import logging
import tempfile
from dataclasses import dataclass, asdict
from typing import Literal, Optional, List, Tuple, Dict
from typing import Literal, Optional, List, Tuple, Dict, Union
from json import load
from pathlib import Path
import lzma

import numpy as np
import torch
import torch.nn.functional as F
from omegaconf import OmegaConf
from vocos import Vocos
from vocos.pretrained import instantiate_class
from huggingface_hub import snapshot_download
Expand Down Expand Up @@ -161,6 +160,28 @@ def unload(self):
def sample_random_speaker(self) -> str:
return self._encode_spk_emb(self._sample_random_speaker())

@torch.inference_mode()
def sample_audio_speaker(self, wav: Union[np.ndarray, torch.Tensor]) -> str:
if isinstance(wav, np.ndarray):
wav = torch.from_numpy(wav)
return self._encode_prompt(self.dvae(wav, "encode").squeeze_(0))

@staticmethod
@torch.no_grad()
def _encode_prompt(prompt: torch.Tensor) -> str:
arr: np.ndarray = prompt.to(dtype=torch.uint16, device="cpu").numpy()
shp = arr.shape
assert len(shp) == 2, "prompt must be a 2D tensor"
s = b14.encode_to_string(
np.array(shp, dtype="<u2").tobytes() + lzma.compress(
arr.astype("<u2").tobytes(),
format=lzma.FORMAT_RAW,
filters=[{"id": lzma.FILTER_LZMA2, "preset": 9 | lzma.PRESET_EXTREME}],
),
)
del arr
return s

@staticmethod
@torch.no_grad()
def _encode_spk_emb(spk_emb: torch.Tensor) -> str:
Expand Down Expand Up @@ -201,6 +222,7 @@ class RefineTextParams:
class InferCodeParams(RefineTextParams):
prompt: str = "[speed_5]"
spk_emb: Optional[str] = None
sample: Optional[str]=None
temperature: float = 0.3
repetition_penalty: float = 1.05
max_new_token: int = 2048
Expand Down Expand Up @@ -459,6 +481,22 @@ def _decode_to_wavs(
del mel_specs
return wavs

@staticmethod
@torch.no_grad()
def _decode_prompt(prompt: str) -> torch.Tensor:
dec = b14.decode_from_string(prompt)
shp = np.frombuffer(dec[:4], dtype="<u2")
p = np.frombuffer(
lzma.decompress(
dec[4:],
format=lzma.FORMAT_RAW,
filters=[{"id": lzma.FILTER_LZMA2, "preset": 9 | lzma.PRESET_EXTREME}],
),
dtype="<u2",
).copy()
del dec
return torch.from_numpy(p).view(*shp)

@staticmethod
def _decode_spk_emb(spk_emb: str) -> np.ndarray:
return np.frombuffer(
Expand Down Expand Up @@ -540,7 +578,10 @@ def _infer_code(
text = [f"[Stts][empty_spk]{i}[Ptts]" for i in text]

input_ids, attention_mask, text_mask = self.tokenizer.encode(
text, self.gpt.num_vq, gpt.device_gpt
text,
self.gpt.num_vq,
prompt=self._decode_prompt(params.sample) if params.sample is not None else None,
device=gpt.device_gpt,
)

emb = gpt(input_ids, text_mask)
Expand Down Expand Up @@ -600,7 +641,7 @@ def _refine_text(
text = [f"[Sbreak]{i}[Pbreak]{params.prompt}" for i in text]

input_ids, attention_mask, text_mask = self.tokenizer.encode(
text, self.gpt.num_vq, gpt.device_gpt
text, self.gpt.num_vq, device=gpt.device_gpt,
)

logits_warpers, logits_processors = gen_logits(
Expand Down
26 changes: 14 additions & 12 deletions ChatTTS/model/dvae.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,24 +95,22 @@ def _embed(self, x: torch.Tensor):
feat = self.quantizer.get_output_from_indices(x)
return feat.transpose_(1, 2) if self.transpose else feat

def __call__(
self, x: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
def __call__(self, x: torch.Tensor) -> torch.Tensor:
return super().__call__(x)

def forward(
self, x: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.transpose:
x.transpose_(1, 2)
feat, ind = self.quantizer(x)
# feat, ind = self.quantizer(x)
_, ind = self.quantizer(x)
"""
ind = rearrange(
ind, "g b t r ->b t (g r)",
)
"""
ind = ind.permute(1, 2, 0, 3).contiguous()
ind = ind.view(ind.size(0), ind.size(1), -1)
"""
embed_onehot_tmp = F.one_hot(ind.long(), self.n_ind)
embed_onehot = embed_onehot_tmp.to(x.dtype)
del embed_onehot_tmp
Expand All @@ -121,12 +119,12 @@ def forward(
torch.div(e_mean, (e_mean.sum(dim=1) + self.eps).unsqueeze(1), out=e_mean)
perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + self.eps), dim=1))
return (
return
torch.zeros(perplexity.shape, dtype=x.dtype, device=x.device),
feat.transpose_(1, 2) if self.transpose else feat,
perplexity,
ind.transpose_(1, 2) if self.transpose else ind,
)
"""
return ind.transpose_(1, 2) if self.transpose else ind


class DVAEDecoder(nn.Module):
Expand Down Expand Up @@ -255,9 +253,13 @@ def forward(
) -> torch.Tensor:
if mode == "encode" and hasattr(self, "encoder") and self.vq_layer is not None:
mel = self.preprocessor_mel(inp)
x: torch.Tensor = self.downsample_conv(mel / self.coef)
x: torch.Tensor = self.downsample_conv(
torch.div(mel, self.coef.view(100, 1).expand(mel.shape), out=mel),
).unsqueeze_(0)
del mel
x = self.encoder(x)
ind = self.vq_layer(x)[3]
ind = self.vq_layer(x)
del x
return ind

if self.vq_layer is not None:
Expand Down
51 changes: 41 additions & 10 deletions ChatTTS/model/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
https://stackoverflow.com/questions/62691279/how-to-disable-tokenizers-parallelism-true-false-warning
"""

from typing import List, Tuple
from typing import List, Tuple, Optional

import torch
from transformers import BertTokenizerFast
Expand All @@ -31,13 +31,19 @@ def __init__(

@torch.inference_mode()
def encode(
self, text: List[str], num_vq: int, device="cpu"
self, text: List[str], num_vq: int, prompt: Optional[torch.Tensor]=None, device="cpu",
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:

input_ids_lst = []
attention_mask_lst = []
max_input_ids_len = -1
max_attention_mask_len = -1
prompt_size = 0

if prompt is not None:
assert prompt.size(0) == num_vq, "prompt dim 0 must equal to num_vq"
prompt_size = prompt.size(1)

# avoid random speaker embedding of tokenizer in the other dims
for t in text:
x = self._tokenizer.encode_plus(
Expand All @@ -52,30 +58,55 @@ def encode(
attn_sz = attention_mask_lst[-1].size(0)
if attn_sz > max_attention_mask_len:
max_attention_mask_len = attn_sz

if prompt is not None:
max_input_ids_len += prompt_size
max_attention_mask_len += prompt_size

input_ids = torch.zeros(
len(input_ids_lst),
max_input_ids_len,
device=device,
dtype=input_ids_lst[0].dtype,
)
for i in range(len(input_ids_lst)):
input_ids.narrow(0, i, 1).narrow(1, 0, input_ids_lst[i].size(0)).copy_(
input_ids_lst[i]
)
input_ids.narrow(0, i, 1).narrow(
1,
max_input_ids_len-prompt_size-input_ids_lst[i].size(0),
input_ids_lst[i].size(0),
).copy_(input_ids_lst[i]) # left padding
del_all(input_ids_lst)

attention_mask = torch.zeros(
len(attention_mask_lst),
max_attention_mask_len,
device=device,
dtype=attention_mask_lst[0].dtype,
)
for i in range(len(attention_mask_lst)):
attention_mask.narrow(0, i, 1).narrow(
1, 0, attention_mask_lst[i].size(0)
).copy_(attention_mask_lst[i])
attn = attention_mask.narrow(0, i, 1)
attn.narrow(
1,
max_attention_mask_len-prompt_size-attention_mask_lst[i].size(0),
attention_mask_lst[i].size(0),
).copy_(attention_mask_lst[i]) # left padding
if prompt_size > 0:
attn.narrow(
1, max_attention_mask_len-prompt_size, prompt_size,
).fill_(1)
del_all(attention_mask_lst)

text_mask = torch.ones(input_ids.shape, dtype=bool, device=device)
input_ids = input_ids.unsqueeze_(-1).expand(-1, -1, num_vq)
new_input_ids = input_ids.unsqueeze_(-1).expand(-1, -1, num_vq).clone()

if prompt_size > 0:
text_mask.narrow(1, max_input_ids_len-prompt_size, prompt_size).fill_(0)
prompt_t = prompt.t().unsqueeze_(0).expand(input_ids.size(0), -1, -1)
new_input_ids.narrow(
1,
max_input_ids_len-prompt_size,
prompt_size,
).copy_(prompt_t)
del prompt_t

return input_ids, attention_mask, text_mask
return new_input_ids, attention_mask, text_mask
1 change: 0 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
numpy<2.0.0
numba
omegaconf>=2.3.0
torch>=2.1.0
torchaudio
tqdm
Expand Down
1 change: 0 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
install_requires=[
"numba",
"numpy<2.0.0",
"omegaconf>=2.3.0",
"pybase16384",
"torch>=2.1.0",
"torchaudio",
Expand Down

0 comments on commit e69ffac

Please sign in to comment.