Skip to content

Commit

Permalink
Fix: cover-pt
Browse files Browse the repository at this point in the history
  • Loading branch information
jianchang512 committed Aug 9, 2024
1 parent 92f57cc commit 99889a0
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 5 deletions.
14 changes: 14 additions & 0 deletions ChatTTS/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,20 @@ def _infer(
# Filter both rows and columns using slicing
yield new_wavs[:][:, keep_cols]

@staticmethod
@torch.no_grad()
def _encode_spk_emb(spk_emb: torch.Tensor) -> str:
arr: np.ndarray = spk_emb.to(dtype=torch.float16, device="cpu").numpy()
s = b14.encode_to_string(
lzma.compress(
arr.tobytes(),
format=lzma.FORMAT_RAW,
filters=[{"id": lzma.FILTER_LZMA2, "preset": 9 | lzma.PRESET_EXTREME}],
),
)
del arr
return s

@torch.inference_mode()
def _vocos_decode(self, spec: torch.Tensor) -> np.ndarray:
if "mps" in str(self.device):
Expand Down
18 changes: 13 additions & 5 deletions cover-pt.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@
import threading
from uilib.cfg import WEB_ADDRESS, SPEAKER_DIR, LOGS_DIR, WAVS_DIR, MODEL_DIR, ROOT_DIR
from uilib import utils,VERSION
from ChatTTS.utils.gpu_utils import select_device
from ChatTTS.utils import select_device
from uilib.utils import is_chinese_os,modelscope_status
merge_size=int(os.getenv('merge_size',10))
env_lang=os.getenv('lang','')
Expand All @@ -70,18 +70,26 @@
else:
is_cn=is_chinese_os()

CHATTTS_DIR= MODEL_DIR+'/pzc163/chatTTS'


chat = ChatTTS.Chat()
device=os.getenv('device','default')
chat.load(source="custom",custom_path=CHATTTS_DIR, device=None if device=='default' else device,compile=True if os.getenv('compile','true').lower()!='false' else False)
device_str=os.getenv('device','default')

if device_str in ['default','mps']:
device=select_device(min_memory=2047,experimental=True if device_str=='mps' else False)
elif device_str =='cuda':
device=select_device(min_memory=2047)
elif device_str == 'cpu':
device = torch.device("cpu")


chat.load(source="custom",custom_path=ROOT_DIR, device=device,compile=True if os.getenv('compile','true').lower()!='false' else False)
n=0
for it in os.listdir('./speaker'):
if it.startswith('seed_') and not it.endswith('_emb-covert.pt'):
print(f'开始转换 {it}')
n+=1
rand_spk=torch.load(f'./speaker/{it}', map_location=select_device(4096) if device=='default' else torch.device(device))
rand_spk=torch.load(f'./speaker/{it}', map_location=device)

torch.save( chat._encode_spk_emb(rand_spk) ,f"{SPEAKER_DIR}/{it.replace('.pt','-covert.pt')}")
if n==0:
Expand Down

0 comments on commit 99889a0

Please sign in to comment.