Skip to content

Commit

Permalink
feat(infer): squeeze the unnecessary leading dim out of output wavs
Browse files Browse the repository at this point in the history
- optimize(all): use the @torch.inference_mode() decorator instead of `with torch.no_grad():` blocks
- fix(core): pass the missing params (ensure_non_empty, stream_batch) to generate
- chore(logging): use warning instead of warn for deprecation
- optimize(gpt): stream triggering timing
- feat(stream): add test for issue 2noise#521
  • Loading branch information
fumiama committed Jul 4, 2024
1 parent 22bf56a commit 0e32ab3
Show file tree
Hide file tree
Showing 8 changed files with 353 additions and 301 deletions.
115 changes: 60 additions & 55 deletions ChatTTS/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,33 +162,33 @@ def sample_random_speaker(self) -> str:
return self._encode_spk_emb(self._sample_random_speaker())

@staticmethod
@torch.inference_mode()
def _encode_spk_emb(spk_emb: torch.Tensor) -> str:
with torch.no_grad():
arr: np.ndarray = spk_emb.to(dtype=torch.float16, device="cpu").numpy()
s = b14.encode_to_string(
lzma.compress(
arr.tobytes(),
format=lzma.FORMAT_RAW,
filters=[
{"id": lzma.FILTER_LZMA2, "preset": 9 | lzma.PRESET_EXTREME}
],
),
)
del arr
arr: np.ndarray = spk_emb.to(dtype=torch.float16, device="cpu").numpy()
s = b14.encode_to_string(
lzma.compress(
arr.tobytes(),
format=lzma.FORMAT_RAW,
filters=[
{"id": lzma.FILTER_LZMA2, "preset": 9 | lzma.PRESET_EXTREME}
],
),
)
del arr
return s

@torch.inference_mode()
def _sample_random_speaker(self) -> torch.Tensor:
with torch.no_grad():
dim: int = self.gpt.gpt.layers[0].mlp.gate_proj.in_features
out: torch.Tensor = self.pretrain_models["spk_stat"]
std, mean = out.chunk(2)
spk = (
torch.randn(dim, device=std.device, dtype=torch.float16)
.mul_(std)
.add_(mean)
)
del out, std, mean
return spk
dim: int = self.gpt.gpt.layers[0].mlp.gate_proj.in_features
out: torch.Tensor = self.pretrain_models["spk_stat"]
std, mean = out.chunk(2)
spk = (
torch.randn(dim, device=std.device, dtype=torch.float16)
.mul_(std)
.add_(mean)
)
del out, std, mean
return spk

@dataclass(repr=False, eq=False)
class RefineTextParams:
Expand All @@ -201,6 +201,7 @@ class RefineTextParams:
min_new_token: int = 0
show_tqdm: bool = True
ensure_non_empty: bool = True
stream_batch: int = 24

@dataclass(repr=False, eq=False)
class InferCodeParams(RefineTextParams):
Expand Down Expand Up @@ -343,6 +344,7 @@ def _load(

return self.has_loaded()

@torch.inference_mode()
def _infer(
self,
text,
Expand Down Expand Up @@ -372,60 +374,59 @@ def _infer(
for t in text
]

with torch.no_grad():

if not skip_refine_text:
refined = self._refine_text(
text,
self.device,
params_refine_text,
)
text_tokens = refined.ids
text_tokens = [
i[i.less(self.tokenizer_break_0_ids)] for i in text_tokens
]
text = self.pretrain_models["tokenizer"].batch_decode(text_tokens)
refined.destroy()
if refine_text_only:
yield text
return

length = [0 for _ in range(len(text))]
for result in self._infer_code(
if not skip_refine_text:
refined = self._refine_text(
text,
stream,
self.device,
use_decoder,
params_infer_code,
):
wav = self._decode_to_wavs(result, length, use_decoder)
yield wav
params_refine_text,
)
text_tokens = refined.ids
text_tokens = [
i[i.less(self.tokenizer_break_0_ids)] for i in text_tokens
]
text = self.pretrain_models["tokenizer"].batch_decode(text_tokens)
refined.destroy()
if refine_text_only:
yield text
return

length = [0 for _ in range(len(text))]
for result in self._infer_code(
text,
stream,
self.device,
use_decoder,
params_infer_code,
):
wav = self._decode_to_wavs(result, length, use_decoder)
result.destroy()
yield wav

@torch.inference_mode()
def _vocos_decode(self, spec: torch.Tensor) -> np.ndarray:
if "mps" in str(self.device):
return self.vocos.decode(spec.cpu()).cpu().numpy()
return self.vocos.decode(spec.cpu()).squeeze_(0).cpu().numpy()
else:
return self.vocos.decode(spec).cpu().numpy()
return self.vocos.decode(spec).squeeze_(0).cpu().numpy()

def _decode_to_wavs(
self, result: GPT.GenerationOutputs, start_seeks: List[int], use_decoder: bool
):
x = result.hiddens if use_decoder else result.ids
wavs: List[np.ndarray] = []
wavs: List[Optional[np.ndarray]] = []
for i, chunk_data in enumerate(x):
start_seek = start_seeks[i]
length = len(chunk_data)
if length <= start_seek:
wavs.append(None)
continue
start_seeks[i] = length
chunk_data = chunk_data[start_seek:]
chunk_data = chunk_data[start_seek:].to(self.device)
decoder = self.decoder if use_decoder else self.dvae
mel_spec = decoder(chunk_data[None].permute(0, 2, 1).to(self.device))
mel_spec = decoder(chunk_data.unsqueeze_(0).permute(0, 2, 1))
del chunk_data
wavs.append(self._vocos_decode(mel_spec))
del_all(mel_spec)
result.destroy()
del_all(x)
return wavs

Expand Down Expand Up @@ -590,6 +591,8 @@ def _infer_code(
return_hidden=return_hidden,
stream=stream,
show_tqdm=params.show_tqdm,
ensure_non_empty=params.ensure_non_empty,
stream_batch=params.stream_batch,
context=self.context,
)

Expand Down Expand Up @@ -639,6 +642,8 @@ def _refine_text(
infer_text=True,
stream=False,
show_tqdm=params.show_tqdm,
ensure_non_empty=params.ensure_non_empty,
stream_batch=params.stream_batch,
context=self.context,
)
)
Expand Down
35 changes: 17 additions & 18 deletions ChatTTS/model/dvae.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,26 +195,25 @@ def __repr__(self) -> str:
self.coef.cpu().numpy().astype(np.float32).tobytes()
)

@torch.inference_mode()
def forward(self, inp: torch.Tensor) -> torch.Tensor:
with torch.no_grad():

if self.vq_layer is not None:
vq_feats = self.vq_layer._embed(inp)
else:
vq_feats = inp
if self.vq_layer is not None:
vq_feats = self.vq_layer._embed(inp)
else:
vq_feats = inp

vq_feats = (
vq_feats.view(
(vq_feats.size(0), 2, vq_feats.size(1) // 2, vq_feats.size(2)),
)
.permute(0, 2, 3, 1)
.flatten(2)
vq_feats = (
vq_feats.view(
(vq_feats.size(0), 2, vq_feats.size(1) // 2, vq_feats.size(2)),
)
.permute(0, 2, 3, 1)
.flatten(2)
)

dec_out = self.out_conv(
self.decoder(
x=vq_feats,
),
)
dec_out = self.out_conv(
self.decoder(
x=vq_feats,
),
)

return torch.mul(dec_out, self.coef, out=dec_out)
return torch.mul(dec_out, self.coef, out=dec_out)
Loading

0 comments on commit 0e32ab3

Please sign in to comment.