optimize(core): move _text_to_token into tokenizer
fumiama committed Jul 5, 2024
1 parent f1c6da9 commit 9db2c87
Showing 2 changed files with 59 additions and 56 deletions.
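In short, the private helper Chat._text_to_token in ChatTTS/core.py becomes Tokenizer.encode in ChatTTS/model/tokenizer.py, and the batch_encode/batch_decode aliases collapse into a single decode alias. A sketch of the call-site change, excerpted from the diff below (not runnable on its own; gpt is the local GPT handle used in core.py):

    # Before: the tokenization helper lived on the Chat object.
    input_ids, attention_mask, text_mask = self._text_to_token(text, gpt.device_gpt)
    text = self.tokenizer.batch_decode(text_tokens)

    # After: the tokenizer owns the logic, and num_vq is passed in explicitly
    # instead of being read from self.gpt inside the helper.
    input_ids, attention_mask, text_mask = self.tokenizer.encode(
        text, self.gpt.num_vq, gpt.device_gpt
    )
    text = self.tokenizer.decode(text_tokens)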
57 changes: 3 additions & 54 deletions ChatTTS/core.py
@@ -368,7 +368,7 @@ def _infer(
             )
             text_tokens = refined.ids
             text_tokens = [i[i.less(self.tokenizer.break_0_ids)] for i in text_tokens]
-            text = self.tokenizer.batch_decode(text_tokens)
+            text = self.tokenizer.decode(text_tokens)
             refined.destroy()
             if refine_text_only:
                 yield text
@@ -415,57 +415,6 @@ def _decode_to_wavs(
         del_all(x)
         return wavs
 
-    @torch.inference_mode()
-    def _text_to_token(
-        self, text: List[str], device="cpu"
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-
-        input_ids_lst = []
-        attention_mask_lst = []
-        max_input_ids_len = -1
-        max_attention_mask_len = -1
-        # avoid random speaker embedding of tokenizer in the other dims
-        for t in text:
-            x = self.tokenizer.batch_encode(
-                t, return_tensors="pt", add_special_tokens=False, padding=True
-            )
-            input_ids_lst.append(x["input_ids"].squeeze_(0))
-            attention_mask_lst.append(x["attention_mask"].squeeze_(0))
-            del_all(x)
-            ids_sz = input_ids_lst[-1].size(0)
-            if ids_sz > max_input_ids_len:
-                max_input_ids_len = ids_sz
-            attn_sz = attention_mask_lst[-1].size(0)
-            if attn_sz > max_attention_mask_len:
-                max_attention_mask_len = attn_sz
-        input_ids = torch.zeros(
-            len(input_ids_lst),
-            max_input_ids_len,
-            device=device,
-            dtype=input_ids_lst[0].dtype,
-        )
-        for i in range(len(input_ids_lst)):
-            input_ids.narrow(0, i, 1).narrow(1, 0, input_ids_lst[i].size(0)).copy_(
-                input_ids_lst[i]
-            )
-        del_all(input_ids_lst)
-        attention_mask = torch.zeros(
-            len(attention_mask_lst),
-            max_attention_mask_len,
-            device=device,
-            dtype=attention_mask_lst[0].dtype,
-        )
-        for i in range(len(attention_mask_lst)):
-            attention_mask.narrow(0, i, 1).narrow(
-                1, 0, attention_mask_lst[i].size(0)
-            ).copy_(attention_mask_lst[i])
-        del_all(attention_mask_lst)
-
-        text_mask = torch.ones(input_ids.shape, dtype=bool, device=device)
-        input_ids = input_ids.unsqueeze_(-1).expand(-1, -1, self.gpt.num_vq)
-
-        return input_ids, attention_mask, text_mask
-
     @staticmethod
     def _decode_spk_emb(spk_emb: str) -> np.ndarray:
         return np.frombuffer(
@@ -546,7 +495,7 @@ def _infer_code(
         else:
             text = [f"[Stts][empty_spk]{i}[Ptts]" for i in text]
 
-        input_ids, attention_mask, text_mask = self._text_to_token(text, gpt.device_gpt)
+        input_ids, attention_mask, text_mask = self.tokenizer.encode(text, self.gpt.num_vq, gpt.device_gpt)
 
         emb = gpt(input_ids, text_mask)
 
@@ -604,7 +553,7 @@ def _refine_text(
 
         text = [f"[Sbreak]{i}[Pbreak]{params.prompt}" for i in text]
 
-        input_ids, attention_mask, text_mask = self._text_to_token(text, gpt.device_gpt)
+        input_ids, attention_mask, text_mask = self.tokenizer.encode(text, self.gpt.num_vq, gpt.device_gpt)
 
         logits_warpers, logits_processors = gen_logits(
             num_code=self.tokenizer.len,
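For context, the helper being moved tokenizes each string separately (the in-code comment says this is to "avoid random speaker embedding of tokenizer in the other dims") and then right-pads everything to the longest sequence in the batch. Below is a minimal, self-contained sketch of that padding step using hand-made token ids; pad_sequence is shown only as an equivalent formulation, while the real code uses an explicit narrow/copy_ loop, and num_vq=4 is just an illustrative value:

    from typing import List, Tuple

    import torch
    from torch.nn.utils.rnn import pad_sequence


    def pad_batch(
        ids_list: List[torch.Tensor], num_vq: int
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Right-pad variable-length id tensors, mirroring what the moved helper builds."""
        # (batch, max_len) ids, zero-padded -- same result as the manual narrow/copy_ loop.
        input_ids = pad_sequence(ids_list, batch_first=True, padding_value=0)
        # 1 for real tokens, 0 for padding.
        attention_mask = pad_sequence(
            [torch.ones_like(t) for t in ids_list], batch_first=True, padding_value=0
        )
        # Shape captured before the VQ expansion, as in the original helper.
        text_mask = torch.ones(input_ids.shape, dtype=torch.bool)
        # Repeat ids across the VQ codebook dimension: (batch, max_len, num_vq).
        input_ids = input_ids.unsqueeze(-1).expand(-1, -1, num_vq)
        return input_ids, attention_mask, text_mask


    # Hand-made "token ids" of different lengths, standing in for tokenizer output.
    ids = [torch.tensor([101, 7592, 102]), torch.tensor([101, 2023, 2003, 2936, 102])]
    input_ids, attention_mask, text_mask = pad_batch(ids, num_vq=4)
    print(input_ids.shape, attention_mask.shape, text_mask.shape)
    # torch.Size([2, 5, 4]) torch.Size([2, 5]) torch.Size([2, 5])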
58 changes: 56 additions & 2 deletions ChatTTS/model/tokenizer.py
@@ -1,6 +1,10 @@
 from typing import List, Tuple
 
+import torch
 from transformers import BertTokenizerFast
 
+from ..utils import del_all
 
 
 class Tokenizer:
     def __init__(
@@ -17,5 +21,55 @@ def __init__(
         self.break_0_ids = tokenizer.convert_tokens_to_ids("[break_0]")
         self.eos_token = tokenizer.convert_tokens_to_ids("[Ebreak]")
 
-        self.batch_encode = self._tokenizer.__call__
-        self.batch_decode = self._tokenizer.batch_decode
+        self.decode = self._tokenizer.batch_decode
 
+    @torch.inference_mode()
+    def encode(
+        self, text: List[str], num_vq:int, device="cpu"
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+
+        input_ids_lst = []
+        attention_mask_lst = []
+        max_input_ids_len = -1
+        max_attention_mask_len = -1
+        # avoid random speaker embedding of tokenizer in the other dims
+        for t in text:
+            x = self._tokenizer(
+                t, return_tensors="pt", add_special_tokens=False, padding=True
+            )
+            input_ids_lst.append(x["input_ids"].squeeze_(0))
+            attention_mask_lst.append(x["attention_mask"].squeeze_(0))
+            del_all(x)
+            ids_sz = input_ids_lst[-1].size(0)
+            if ids_sz > max_input_ids_len:
+                max_input_ids_len = ids_sz
+            attn_sz = attention_mask_lst[-1].size(0)
+            if attn_sz > max_attention_mask_len:
+                max_attention_mask_len = attn_sz
+        input_ids = torch.zeros(
+            len(input_ids_lst),
+            max_input_ids_len,
+            device=device,
+            dtype=input_ids_lst[0].dtype,
+        )
+        for i in range(len(input_ids_lst)):
+            input_ids.narrow(0, i, 1).narrow(1, 0, input_ids_lst[i].size(0)).copy_(
+                input_ids_lst[i]
+            )
+        del_all(input_ids_lst)
+        attention_mask = torch.zeros(
+            len(attention_mask_lst),
+            max_attention_mask_len,
+            device=device,
+            dtype=attention_mask_lst[0].dtype,
+        )
+        for i in range(len(attention_mask_lst)):
+            attention_mask.narrow(0, i, 1).narrow(
+                1, 0, attention_mask_lst[i].size(0)
+            ).copy_(attention_mask_lst[i])
+        del_all(attention_mask_lst)
+
+        text_mask = torch.ones(input_ids.shape, dtype=bool, device=device)
+        input_ids = input_ids.unsqueeze_(-1).expand(-1, -1, num_vq)
+
+        return input_ids, attention_mask, text_mask
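
A hedged usage sketch of the new entry point, assuming a loaded Chat instance; chat.load(), the prompt string format, and attribute names such as chat.gpt.num_vq and chat.gpt.device_gpt follow the call sites visible in core.py around this time and may differ between versions:

    import ChatTTS

    chat = ChatTTS.Chat()
    chat.load()  # assumption: loads the default models for this version of the repo

    input_ids, attention_mask, text_mask = chat.tokenizer.encode(
        ["[Stts][empty_spk]Hello there.[Ptts]",
         "[Stts][empty_spk]A longer second prompt.[Ptts]"],
        chat.gpt.num_vq,
        device=chat.gpt.device_gpt,
    )
    # input_ids:      (batch, seq_len, num_vq) -- ids repeated across the VQ codebook dim
    # attention_mask: (batch, seq_len)         -- 1 for real tokens, 0 for right padding
    # text_mask:      (batch, seq_len), bool   -- marks these positions as text tokens

The GPT forward in core.py then consumes the first and last of these directly, as in emb = gpt(input_ids, text_mask) above.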
