diff --git a/bert/bert_models.json b/bert/bert_models.json
index 78721c363..cd5906d11 100644
--- a/bert/bert_models.json
+++ b/bert/bert_models.json
@@ -1,6 +1,6 @@
 {
-    "deberta-v2-large-japanese": {
-        "repo_id": "ku-nlp/deberta-v2-large-japanese",
+    "deberta-v2-large-japanese-char-wwm": {
+        "repo_id": "ku-nlp/deberta-v2-large-japanese-char-wwm",
         "files": ["pytorch_model.bin"]
     },
     "chinese-roberta-wwm-ext-large": {
diff --git a/bert_gen.py b/bert_gen.py
index 20e2cac5b..f06705426 100644
--- a/bert_gen.py
+++ b/bert_gen.py
@@ -3,7 +3,7 @@
 import commons
 import utils
 from tqdm import tqdm
-from text import check_bert_models, cleaned_text_to_sequence, get_bert
+from text import cleaned_text_to_sequence, get_bert
 import argparse
 import torch.multiprocessing as mp
 from config import config
@@ -57,7 +57,6 @@ def process_line(line):
     args, _ = parser.parse_known_args()
     config_path = args.config
     hps = utils.get_hparams_from_file(config_path)
-    check_bert_models()
     lines = []
     with open(hps.data.training_files, encoding="utf-8") as f:
         lines.extend(f.readlines())
diff --git a/emo_gen.py b/emo_gen.py
index 0fb36d100..fdaa3fb0a 100644
--- a/emo_gen.py
+++ b/emo_gen.py
@@ -1,19 +1,21 @@
+import argparse
+import os
+from pathlib import Path
+
+import librosa
+import numpy as np
 import torch
 import torch.nn as nn
-from torch.utils.data import Dataset
-from torch.utils.data import DataLoader
+from torch.utils.data import DataLoader, Dataset
+from tqdm import tqdm
 from transformers import Wav2Vec2Processor
 from transformers.models.wav2vec2.modeling_wav2vec2 import (
     Wav2Vec2Model,
     Wav2Vec2PreTrainedModel,
 )
-import librosa
-import numpy as np
-import argparse
-from config import config
+
 import utils
-import os
-from tqdm import tqdm
+from config import config
 
 
 class RegressionHead(nn.Module):
@@ -78,11 +80,6 @@ def __getitem__(self, idx):
         return torch.from_numpy(processed_data)
 
 
-model_name = "./emotional/wav2vec2-large-robust-12-ft-emotion-msp-dim"
-processor = Wav2Vec2Processor.from_pretrained(model_name)
-model = EmotionModel.from_pretrained(model_name)
-
-
 def process_func(
     x: np.ndarray,
     sampling_rate: int,
@@ -135,16 +132,12 @@ def get_emo(path):
     device = config.bert_gen_config.device
     model_name = "./emotional/wav2vec2-large-robust-12-ft-emotion-msp-dim"
-    processor = (
-        Wav2Vec2Processor.from_pretrained(model_name)
-        if processor is None
-        else processor
-    )
-    model = (
-        EmotionModel.from_pretrained(model_name).to(device)
-        if model is None
-        else model.to(device)
-    )
+    REPO_ID = "audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim"
+    if not Path(model_name).joinpath("pytorch_model.bin").exists():
+        utils.download_emo_models(config.mirror, REPO_ID, model_name)
+
+    processor = Wav2Vec2Processor.from_pretrained(model_name)
+    model = EmotionModel.from_pretrained(model_name).to(device)
 
     lines = []
     with open(hps.data.training_files, encoding="utf-8") as f:
         lines.extend(f.readlines())
diff --git a/text/__init__.py b/text/__init__.py
index 54cbef15f..45592e0a6 100644
--- a/text/__init__.py
+++ b/text/__init__.py
@@ -46,3 +46,6 @@ def check_bert_models():
     for k, v in models.items():
         local_path = Path("./bert").joinpath(k)
         _check_bert(v["repo_id"], v["files"], local_path)
+
+
+check_bert_models()
diff --git a/train_ms.py b/train_ms.py
index 4a2abc49b..659669f19 100644
--- a/train_ms.py
+++ b/train_ms.py
@@ -130,7 +130,7 @@ def run():
     collate_fn = TextAudioSpeakerCollate()
     train_loader = DataLoader(
         train_dataset,
-        num_workers=config.train_ms_config.num_workers,
+        num_workers=min(config.train_ms_config.num_workers, os.cpu_count() - 1),
         shuffle=False,
         pin_memory=True,
         collate_fn=collate_fn,
diff --git a/utils.py b/utils.py
index 12ad286ad..7c1440593 100644
--- a/utils.py
+++ b/utils.py
@@ -16,6 +16,24 @@
 logger = logging.getLogger(__name__)
 
 
+def download_emo_models(mirror, repo_id, model_name):
+    if mirror == "openi":
+        import openi
+
+        openi.model.download_model(
+            "Stardust_minus/Bert-VITS2",
+            repo_id.split("/")[-1],
+            "./emotional",
+        )
+    else:
+        hf_hub_download(
+            repo_id,
+            "pytorch_model.bin",
+            local_dir=model_name,
+            local_dir_use_symlinks=False,
+        )
+
+
 def download_checkpoint(
     dir_path, repo_config, token=None, regex="G_*.pth", mirror="openi"
 ):
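
Taken together, these changes defer model downloads to first use: check_bert_models() now runs as a side effect of importing the text package (which is why the explicit call was dropped from bert_gen.py), and the emotion-model weights are only fetched when pytorch_model.bin is missing locally. A minimal sketch of the resulting bootstrap, assuming the repo's utils and config modules are importable; the wrapper script itself is hypothetical and not part of this diff:

    # Importing the text package now triggers check_bert_models(), which
    # fetches any missing BERT weights listed in bert/bert_models.json.
    from text import cleaned_text_to_sequence, get_bert

    from pathlib import Path

    import utils
    from config import config

    # Emotion-model weights are fetched lazily, only when pytorch_model.bin
    # is absent from the local ./emotional checkpoint directory.
    model_name = "./emotional/wav2vec2-large-robust-12-ft-emotion-msp-dim"
    REPO_ID = "audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim"
    if not Path(model_name).joinpath("pytorch_model.bin").exists():
        utils.download_emo_models(config.mirror, REPO_ID, model_name)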