diff --git a/ChatTTS/core.py b/ChatTTS/core.py
index f6c7f5aeb..3715f6970 100644
--- a/ChatTTS/core.py
+++ b/ChatTTS/core.py
@@ -2,7 +2,7 @@
 import logging
 import tempfile
 from dataclasses import dataclass, asdict
-from typing import Literal, Optional, List, Tuple, Dict
+from typing import Literal, Optional, List, Tuple, Dict, Union
 from json import load
 from pathlib import Path
 import lzma
@@ -10,7 +10,6 @@
 import numpy as np
 import torch
 import torch.nn.functional as F
-from omegaconf import OmegaConf
 from vocos import Vocos
 from vocos.pretrained import instantiate_class
 from huggingface_hub import snapshot_download
@@ -161,6 +160,28 @@ def unload(self):
     def sample_random_speaker(self) -> str:
         return self._encode_spk_emb(self._sample_random_speaker())
 
+    @torch.inference_mode()
+    def sample_audio_speaker(self, wav: Union[np.ndarray, torch.Tensor]) -> str:
+        if isinstance(wav, np.ndarray):
+            wav = torch.from_numpy(wav)
+        return self._encode_prompt(self.dvae(wav, "encode").squeeze_(0))
+
+    @staticmethod
+    @torch.no_grad()
+    def _encode_prompt(prompt: torch.Tensor) -> str:
+        arr: np.ndarray = prompt.to(dtype=torch.uint16, device="cpu").numpy()
+        shp = arr.shape
+        assert len(shp) == 2, "prompt must be a 2D tensor"
+        s = b14.encode_to_string(
+            np.array(shp, dtype="<u2").tobytes() + lzma.compress(
+                arr.tobytes(),
+                format=lzma.FORMAT_RAW,
+                filters=[{"id": lzma.FILTER_LZMA2, "preset": 9 | lzma.PRESET_EXTREME}],
+            ),
+        )
+        del arr
+        return s
+
     @staticmethod
     @torch.no_grad()
     def _encode_spk_emb(spk_emb: torch.Tensor) -> str:
@@ -201,6 +222,7 @@ class RefineTextParams:
     class InferCodeParams(RefineTextParams):
         prompt: str = "[speed_5]"
         spk_emb: Optional[str] = None
+        sample: Optional[str]=None
         temperature: float = 0.3
         repetition_penalty: float = 1.05
         max_new_token: int = 2048
@@ -459,6 +481,22 @@ def _decode_to_wavs(
         del mel_specs
         return wavs
 
+    @staticmethod
+    @torch.no_grad()
+    def _decode_prompt(prompt: str) -> torch.Tensor:
+        dec = b14.decode_from_string(prompt)
+        shp = np.frombuffer(dec[:4], dtype="<u2")
+        arr = np.frombuffer(
+            lzma.decompress(
+                dec[4:],
+                format=lzma.FORMAT_RAW,
+                filters=[{"id": lzma.FILTER_LZMA2, "preset": 9 | lzma.PRESET_EXTREME}],
+            ),
+            dtype="<u2",
+        ).copy()
+        del dec
+        return torch.from_numpy(arr).view(*shp)
+
     @staticmethod
     @torch.no_grad()
     def _decode_spk_emb(spk_emb: str) -> np.ndarray:
         return np.frombuffer(
@@ -540,7 +578,10 @@
             text = [f"[Stts][empty_spk]{i}[Ptts]" for i in text]
 
         input_ids, attention_mask, text_mask = self.tokenizer.encode(
-            text, self.gpt.num_vq, gpt.device_gpt
+            text,
+            self.gpt.num_vq,
+            prompt=self._decode_prompt(params.sample) if params.sample is not None else None,
+            device=gpt.device_gpt,
         )
 
         emb = gpt(input_ids, text_mask)
@@ -600,7 +641,7 @@
         text = [f"[Sbreak]{i}[Pbreak]{params.prompt}" for i in text]
 
         input_ids, attention_mask, text_mask = self.tokenizer.encode(
-            text, self.gpt.num_vq, gpt.device_gpt
+            text, self.gpt.num_vq, device=gpt.device_gpt,
         )
 
         logits_warpers, logits_processors = gen_logits(
diff --git a/ChatTTS/model/dvae.py b/ChatTTS/model/dvae.py
index a5cb5a0fc..ff00980d8 100644
--- a/ChatTTS/model/dvae.py
+++ b/ChatTTS/model/dvae.py
@@ -95,17 +95,14 @@ def _embed(self, x: torch.Tensor):
         feat = self.quantizer.get_output_from_indices(x)
         return feat.transpose_(1, 2) if self.transpose else feat
 
-    def __call__(
-        self, x: torch.Tensor
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+    def __call__(self, x: torch.Tensor) -> torch.Tensor:
         return super().__call__(x)
 
-    def forward(
-        self, x: torch.Tensor
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         if self.transpose:
             x.transpose_(1, 2)
-        feat, ind = self.quantizer(x)
+        # feat, ind = self.quantizer(x)
+        _, ind = self.quantizer(x)
         """
         ind = rearrange(
             ind, "g b t r ->b t (g r)",
@@ -113,6 +110,7 @@
         """
         ind = ind.permute(1, 2, 0, 3).contiguous()
         ind = ind.view(ind.size(0), ind.size(1), -1)
+        """
         embed_onehot_tmp = F.one_hot(ind.long(), self.n_ind)
         embed_onehot = embed_onehot_tmp.to(x.dtype)
         del embed_onehot_tmp
@@ -121,12 +119,12 @@
         torch.div(e_mean, (e_mean.sum(dim=1) + self.eps).unsqueeze(1), out=e_mean)
         perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + self.eps), dim=1))
 
-        return (
+        return
             torch.zeros(perplexity.shape, dtype=x.dtype, device=x.device),
             feat.transpose_(1, 2) if self.transpose else feat,
             perplexity,
-            ind.transpose_(1, 2) if self.transpose else ind,
-        )
+        """
+        return ind.transpose_(1, 2) if self.transpose else ind
 
 
 class DVAEDecoder(nn.Module):
@@ -255,9 +253,13 @@
     ) -> torch.Tensor:
         if mode == "encode" and hasattr(self, "encoder") and self.vq_layer is not None:
             mel = self.preprocessor_mel(inp)
-            x: torch.Tensor = self.downsample_conv(mel / self.coef)
+            x: torch.Tensor = self.downsample_conv(
+                torch.div(mel, self.coef.view(100, 1).expand(mel.shape), out=mel),
+            ).unsqueeze_(0)
+            del mel
             x = self.encoder(x)
-            ind = self.vq_layer(x)[3]
+            ind = self.vq_layer(x)
+            del x
             return ind
 
         if self.vq_layer is not None:
diff --git a/ChatTTS/model/tokenizer.py b/ChatTTS/model/tokenizer.py
index c90e71236..d949fc713 100644
--- a/ChatTTS/model/tokenizer.py
+++ b/ChatTTS/model/tokenizer.py
@@ -5,7 +5,7 @@
 https://stackoverflow.com/questions/62691279/how-to-disable-tokenizers-parallelism-true-false-warning
 """
 
-from typing import List, Tuple
+from typing import List, Tuple, Optional
 
 import torch
 from transformers import BertTokenizerFast
@@ -31,13 +31,19 @@ def __init__(
 
     @torch.inference_mode()
     def encode(
-        self, text: List[str], num_vq: int, device="cpu"
+        self, text: List[str], num_vq: int, prompt: Optional[torch.Tensor]=None, device="cpu",
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         input_ids_lst = []
         attention_mask_lst = []
         max_input_ids_len = -1
         max_attention_mask_len = -1
+        prompt_size = 0
+
+        if prompt is not None:
+            assert prompt.size(0) == num_vq, "prompt dim 0 must equal to num_vq"
+            prompt_size = prompt.size(1)
+
         # avoid random speaker embedding of tokenizer in the other dims
         for t in text:
             x = self._tokenizer.encode_plus(
                 t, return_tensors="pt", add_special_tokens=False, padding=True
@@ -52,6 +58,11 @@
             attn_sz = attention_mask_lst[-1].size(0)
             if attn_sz > max_attention_mask_len:
                 max_attention_mask_len = attn_sz
+
+        if prompt is not None:
+            max_input_ids_len += prompt_size
+            max_attention_mask_len += prompt_size
+
         input_ids = torch.zeros(
             len(input_ids_lst),
             max_input_ids_len,
@@ -59,10 +70,13 @@
             device=device,
             dtype=input_ids_lst[0].dtype,
         )
         for i in range(len(input_ids_lst)):
-            input_ids.narrow(0, i, 1).narrow(1, 0, input_ids_lst[i].size(0)).copy_(
-                input_ids_lst[i]
-            )
+            input_ids.narrow(0, i, 1).narrow(
+                1,
+                max_input_ids_len-prompt_size-input_ids_lst[i].size(0),
+                input_ids_lst[i].size(0),
+            ).copy_(input_ids_lst[i]) # left padding
         del_all(input_ids_lst)
+
         attention_mask = torch.zeros(
             len(attention_mask_lst),
             max_attention_mask_len,
@@ -70,12 +84,29 @@
             device=device,
             dtype=attention_mask_lst[0].dtype,
         )
         for i in range(len(attention_mask_lst)):
-            attention_mask.narrow(0, i, 1).narrow(
-                1, 0, attention_mask_lst[i].size(0)
-            ).copy_(attention_mask_lst[i])
+            attn = attention_mask.narrow(0, i, 1)
+            attn.narrow(
+                1,
+                max_attention_mask_len-prompt_size-attention_mask_lst[i].size(0),
+                attention_mask_lst[i].size(0),
+            ).copy_(attention_mask_lst[i]) # left padding
+            if prompt_size > 0:
+                attn.narrow(
+                    1, max_attention_mask_len-prompt_size, prompt_size,
+                ).fill_(1)
         del_all(attention_mask_lst)
 
         text_mask = torch.ones(input_ids.shape, dtype=bool, device=device)
-        input_ids = input_ids.unsqueeze_(-1).expand(-1, -1, num_vq)
+        new_input_ids = input_ids.unsqueeze_(-1).expand(-1, -1, num_vq).clone()
+
+        if prompt_size > 0:
+            text_mask.narrow(1, max_input_ids_len-prompt_size, prompt_size).fill_(0)
+            prompt_t = prompt.t().unsqueeze_(0).expand(input_ids.size(0), -1, -1)
+            new_input_ids.narrow(
+                1,
+                max_input_ids_len-prompt_size,
+                prompt_size,
+            ).copy_(prompt_t)
+            del prompt_t
 
-        return input_ids, attention_mask, text_mask
+        return new_input_ids, attention_mask, text_mask
diff --git a/requirements.txt b/requirements.txt
index 7578f4273..75066bb96 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,6 +1,5 @@
 numpy<2.0.0
 numba
-omegaconf>=2.3.0
 torch>=2.1.0
 torchaudio
 tqdm
diff --git a/setup.py b/setup.py
index 7c47b4ba5..da7b609e6 100644
--- a/setup.py
+++ b/setup.py
@@ -21,7 +21,6 @@
     install_requires=[
         "numba",
         "numpy<2.0.0",
-        "omegaconf>=2.3.0",
         "pybase16384",
         "torch>=2.1.0",
         "torchaudio",
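Usage note (not part of the patch): a minimal sketch of how the new zero-shot prompt path could be exercised end to end. It relies on the existing ChatTTS.Chat.load()/infer() entry points and the sample_audio_speaker()/InferCodeParams.sample additions above; "voice.wav" is a placeholder path, and the mono 24 kHz input rate expected by the DVAE mel frontend is an assumption.

import torchaudio

import ChatTTS

chat = ChatTTS.Chat()
chat.load()  # load model assets as usual

# Load and normalize the reference clip: 1-D mono waveform at 24 kHz (assumed).
wav, sr = torchaudio.load("voice.wav")  # "voice.wav" is a placeholder path
wav = torchaudio.functional.resample(wav, sr, 24000).mean(0)

# New in this patch: encode the reference audio into a DVAE prompt string.
spk_sample = chat.sample_audio_speaker(wav)

# The new "sample" field carries the encoded prompt into _infer_code via the tokenizer.
params = ChatTTS.Chat.InferCodeParams(
    sample=spk_sample,
    temperature=0.3,
)
wavs = chat.infer(["Hello, this is a cloned voice."], params_infer_code=params)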