Skip to content

Commit

Permalink
Auto download update, optimize dataloader num_workers (fishaudio#195)
Browse files Browse the repository at this point in the history
* update bert

* auto download emo

* fix typo

* fix typo

* fix bert download

* optimize code format

* remove unused import

* fix a bug

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
Isotr0py and pre-commit-ci[bot] authored Nov 26, 2023
1 parent 15babcd commit dec3fc0
Show file tree
Hide file tree
Showing 6 changed files with 41 additions and 28 deletions.
4 changes: 2 additions & 2 deletions bert/bert_models.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"deberta-v2-large-japanese": {
"repo_id": "ku-nlp/deberta-v2-large-japanese",
"deberta-v2-large-japanese-char-wwm": {
"repo_id": "ku-nlp/deberta-v2-large-japanese-char-wwm",
"files": ["pytorch_model.bin"]
},
"chinese-roberta-wwm-ext-large": {
Expand Down
3 changes: 1 addition & 2 deletions bert_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import commons
import utils
from tqdm import tqdm
from text import check_bert_models, cleaned_text_to_sequence, get_bert
from text import cleaned_text_to_sequence, get_bert
import argparse
import torch.multiprocessing as mp
from config import config
Expand Down Expand Up @@ -57,7 +57,6 @@ def process_line(line):
args, _ = parser.parse_known_args()
config_path = args.config
hps = utils.get_hparams_from_file(config_path)
check_bert_models()
lines = []
with open(hps.data.training_files, encoding="utf-8") as f:
lines.extend(f.readlines())
Expand Down
39 changes: 16 additions & 23 deletions emo_gen.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,21 @@
import argparse
import os
from pathlib import Path

import librosa
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import Wav2Vec2Processor
from transformers.models.wav2vec2.modeling_wav2vec2 import (
Wav2Vec2Model,
Wav2Vec2PreTrainedModel,
)
import librosa
import numpy as np
import argparse
from config import config

import utils
import os
from tqdm import tqdm
from config import config


class RegressionHead(nn.Module):
Expand Down Expand Up @@ -78,11 +80,6 @@ def __getitem__(self, idx):
return torch.from_numpy(processed_data)


model_name = "./emotional/wav2vec2-large-robust-12-ft-emotion-msp-dim"
processor = Wav2Vec2Processor.from_pretrained(model_name)
model = EmotionModel.from_pretrained(model_name)


def process_func(
x: np.ndarray,
sampling_rate: int,
Expand Down Expand Up @@ -135,16 +132,12 @@ def get_emo(path):
device = config.bert_gen_config.device

model_name = "./emotional/wav2vec2-large-robust-12-ft-emotion-msp-dim"
processor = (
Wav2Vec2Processor.from_pretrained(model_name)
if processor is None
else processor
)
model = (
EmotionModel.from_pretrained(model_name).to(device)
if model is None
else model.to(device)
)
# Repository id handed to the download helper when the weights are missing.
REPO_ID = "audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim"
# Only download when the local checkpoint file is not already present.
if not Path(model_name).joinpath("pytorch_model.bin").exists():
    # utils.download_emo_models is declared as (mirror, repo_id, model_name).
    # The previous positional call passed (mirror, model_name, REPO_ID), which
    # swapped the local directory and the repository id. Keyword arguments
    # bind each value to the parameter it is meant for.
    utils.download_emo_models(
        mirror=config.mirror, repo_id=REPO_ID, model_name=model_name
    )

processor = Wav2Vec2Processor.from_pretrained(model_name)
model = EmotionModel.from_pretrained(model_name).to(device)

lines = []
with open(hps.data.training_files, encoding="utf-8") as f:
Expand Down
3 changes: 3 additions & 0 deletions text/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,6 @@ def check_bert_models():
for k, v in models.items():
local_path = Path("./bert").joinpath(k)
_check_bert(v["repo_id"], v["files"], local_path)


check_bert_models()
2 changes: 1 addition & 1 deletion train_ms.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def run():
collate_fn = TextAudioSpeakerCollate()
train_loader = DataLoader(
train_dataset,
num_workers=config.train_ms_config.num_workers,
num_workers=min(config.train_ms_config.num_workers, os.cpu_count() - 1),
shuffle=False,
pin_memory=True,
collate_fn=collate_fn,
Expand Down
18 changes: 18 additions & 0 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,24 @@
logger = logging.getLogger(__name__)


def download_emo_models(mirror, repo_id, model_name):
    """Download the emotion model weights from the selected mirror.

    Args:
        mirror: Name of the download source. The value ``"openi"`` selects
            the openi client; any other value uses ``hf_hub_download``.
        repo_id: Repository id of the model. For the openi mirror only its
            last ``/``-separated component is used as the model name.
        model_name: Local directory that receives ``pytorch_model.bin`` when
            downloading through ``hf_hub_download``. It is not used by the
            openi branch, which always writes to ``./emotional``.
    """
    if mirror != "openi":
        # Default path: fetch the single weight file as a real file
        # (no symlinks) into the requested local directory.
        hf_hub_download(
            repo_id,
            "pytorch_model.bin",
            local_dir=model_name,
            local_dir_use_symlinks=False,
        )
        return

    # Imported lazily so the openi package is only required for this mirror.
    import openi

    short_name = repo_id.split("/")[-1]
    openi.model.download_model(
        "Stardust_minus/Bert-VITS2",
        short_name,
        "./emotional",
    )


def download_checkpoint(
dir_path, repo_config, token=None, regex="G_*.pth", mirror="openi"
):
Expand Down

0 comments on commit dec3fc0

Please sign in to comment.